| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | /*************************************************************************************************** | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							|  |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /*! \file | 
					
						
							|  |  |  |     \brief CUTLASS Library handle. | 
					
						
							|  |  |  | */ | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | #include <iostream>  | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | #include <stdexcept> | 
					
						
							|  |  |  | #include <cstdint> | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/library/handle.h" | 
					
						
							|  |  |  | #include "cutlass/library/singleton.h" | 
					
						
							|  |  |  | #include "cutlass/library/util.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace library { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Constructor | 
					
						
							|  |  |  | Handle::Handle( | 
					
						
							|  |  |  |   cudaStream_t stream,  | 
					
						
							|  |  |  |   size_t workspace_size | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | ): | 
					
						
							|  |  |  |   provider_(Provider::kCUTLASS),  | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   stream_(stream),  | 
					
						
							|  |  |  |   workspace_(nullptr),  | 
					
						
							|  |  |  |   workspace_size_(0),  | 
					
						
							|  |  |  |   scalar_pointer_mode_(ScalarPointerMode::kHost),  | 
					
						
							|  |  |  |   last_operation_(nullptr) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int device_idx = -1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cudaError_t error = cudaGetDevice(&device_idx); | 
					
						
							|  |  |  |   if (error != cudaSuccess) { | 
					
						
							|  |  |  |     throw std::runtime_error("cudaGetDevice() failed"); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   error = cudaGetDeviceProperties(&device_, device_idx); | 
					
						
							|  |  |  |   if (error != cudaSuccess) { | 
					
						
							|  |  |  |     throw std::runtime_error("cudaGetDeviceProperties() failed"); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   set_workspace_size(workspace_size); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Singleton::get(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Destructor | 
					
						
							|  |  |  | Handle::~Handle() { | 
					
						
							|  |  |  |   if (workspace_) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (workspace_) { | 
					
						
							|  |  |  |       cudaFree(workspace_); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     workspace_ = nullptr; | 
					
						
							|  |  |  |     workspace_size_ = 0; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Move constructor | 
					
						
							|  |  |  | Handle::Handle(Handle && handle) { | 
					
						
							|  |  |  |   device_ = handle.device_; | 
					
						
							|  |  |  |   workspace_size_ = handle.workspace_size_; | 
					
						
							|  |  |  |   workspace_ = handle.workspace_; | 
					
						
							|  |  |  |   stream_ = handle.stream_; | 
					
						
							|  |  |  |   scalar_pointer_mode_ = handle.scalar_pointer_mode_; | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   handle.workspace_ = nullptr; | 
					
						
							|  |  |  |   handle.workspace_size_ = 0; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Move assignment operator | 
					
						
							|  |  |  | Handle & Handle::operator=(Handle && handle) { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   provider_ = handle.provider_; | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   device_ = handle.device_; | 
					
						
							|  |  |  |   workspace_size_ = handle.workspace_size_; | 
					
						
							|  |  |  |   workspace_ = handle.workspace_; | 
					
						
							|  |  |  |   stream_ = handle.stream_; | 
					
						
							|  |  |  |   scalar_pointer_mode_ = handle.scalar_pointer_mode_; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   handle.workspace_ = nullptr; | 
					
						
							|  |  |  |   handle.workspace_size_ = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return *this; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | int Handle::compute_capability() const { | 
					
						
							|  |  |  |   return device_.major * 10 + device_.minor; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Sets the current CUDA stream | 
					
						
							|  |  |  | void Handle::set_stream(cudaStream_t stream) { | 
					
						
							|  |  |  |   stream_ = stream; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Gets the current CUDA stream | 
					
						
							|  |  |  | cudaStream_t Handle::get_stream() const { | 
					
						
							|  |  |  |   return stream_; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | /// Gets the current provider | 
					
						
							|  |  |  | Provider Handle::get_provider() const { | 
					
						
							|  |  |  |   return provider_; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Sets the provider of operations | 
					
						
							|  |  |  | void Handle::set_provider(Provider provider) { | 
					
						
							|  |  |  |   provider_ = provider; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | /// Gets the device workspace size | 
					
						
							|  |  |  | size_t Handle::get_workspace_size() const { | 
					
						
							|  |  |  |   return workspace_size_; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Gets a pointer to the device workspace allocation in Global Memory | 
					
						
							|  |  |  | void *Handle::get_workspace() const { | 
					
						
							|  |  |  |   return workspace_; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Sets the size of device workspace, invalidating previous calls to get_device_workspace() | 
					
						
							|  |  |  | void Handle::set_workspace_size(size_t bytes) { | 
					
						
							|  |  |  |   if (bytes != workspace_size_) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (workspace_) { | 
					
						
							|  |  |  |       cudaFree(workspace_); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |        | 
					
						
							|  |  |  |     workspace_ = nullptr; | 
					
						
							|  |  |  |     workspace_size_ = bytes; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (workspace_size_) { | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |       cudaError_t error = cudaMalloc((void **)&workspace_, workspace_size_); | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |       if (error != cudaSuccess) { | 
					
						
							|  |  |  |         throw std::runtime_error("Failed to allocate workspace"); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (workspace_) { | 
					
						
							|  |  |  |     cudaError_t error = cudaMemset(workspace_, 0, workspace_size_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (error != cudaSuccess) { | 
					
						
							|  |  |  |       throw std::runtime_error("Failed to clear workspace"); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Gets the scalar pointer mode | 
					
						
							|  |  |  | ScalarPointerMode Handle::get_scalar_pointer_mode() const { | 
					
						
							|  |  |  |   return scalar_pointer_mode_; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Sets the scalar pointer mode | 
					
						
							|  |  |  | void Handle::set_scalar_pointer_mode(ScalarPointerMode mode) { | 
					
						
							|  |  |  |   scalar_pointer_mode_ = mode; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Gets the last operation | 
					
						
							|  |  |  | Operation const *Handle::get_last_operation() const { | 
					
						
							|  |  |  |   return last_operation_; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Returns the maximum required alignment for each operator | 
					
						
							|  |  |  | static int maximum_alignment_requirement(GemmDescription const &desc) { | 
					
						
							|  |  |  |   return std::max( | 
					
						
							|  |  |  |     std::max(desc.A.alignment, desc.B.alignment), desc.C.alignment); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Returns the largest alignment (in units of elements) the problem satisfies, starting from a | 
					
						
							|  |  |  | /// given upper limit. | 
					
						
							|  |  |  | static int gemm_problem_alignment( | 
					
						
							|  |  |  |   int M, | 
					
						
							|  |  |  |   int N, | 
					
						
							|  |  |  |   int K, | 
					
						
							|  |  |  |   NumericTypeID element_A, | 
					
						
							|  |  |  |   void const *ptr_A, | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t lda, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   int64_t batch_stride_A, | 
					
						
							|  |  |  |   NumericTypeID element_B, | 
					
						
							|  |  |  |   void const *ptr_B, | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldb, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   int64_t batch_stride_B, | 
					
						
							|  |  |  |   NumericTypeID element_C, | 
					
						
							|  |  |  |   void const * ptr_C, | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldc, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   int64_t batch_stride_C, | 
					
						
							|  |  |  |   void const * ptr_D, | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldd, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |   int64_t batch_stride_D, | 
					
						
							|  |  |  |   int max_alignment_in_bytes = 16 | 
					
						
							|  |  |  | ) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const *pointers[] = { | 
					
						
							|  |  |  |     ptr_A, ptr_B, ptr_C, ptr_D | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t extents[] = { | 
					
						
							|  |  |  |     M, N, K, lda, ldb, ldc, ldd, batch_stride_A, batch_stride_B, batch_stride_C, batch_stride_D | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID elements[] = { | 
					
						
							|  |  |  |     element_A, element_B, element_C | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   for (; max_alignment_in_bytes > 0; max_alignment_in_bytes /= 2) { | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     bool satisfied = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Can pointers satisfy this? | 
					
						
							|  |  |  |     for (void const *ptr : pointers) { | 
					
						
							|  |  |  |       std::uintptr_t int_ptr = reinterpret_cast<std::uintptr_t>(ptr); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (int_ptr % max_alignment_in_bytes) { | 
					
						
							|  |  |  |         satisfied = false; | 
					
						
							|  |  |  |         break; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (!satisfied) { | 
					
						
							|  |  |  |       continue; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Compute the maximum alignment based on element data types | 
					
						
							|  |  |  |     int max_element_alignment = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (NumericTypeID type_id : elements) { | 
					
						
							|  |  |  |       int element_alignment = max_alignment_in_bytes * 8 / library::sizeof_bits(type_id);  | 
					
						
							|  |  |  |       max_element_alignment = std::max(max_element_alignment, element_alignment); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Can the problem size and leading dimensions satisfy this? | 
					
						
							|  |  |  |     for (int64_t extent : extents) { | 
					
						
							|  |  |  |       if (extent % max_element_alignment) { | 
					
						
							|  |  |  |         satisfied = false; | 
					
						
							|  |  |  |         break; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (!satisfied) { | 
					
						
							|  |  |  |       continue; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Yes | 
					
						
							|  |  |  |     return max_element_alignment; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // No alignment satisfies this problem | 
					
						
							|  |  |  |   return 0; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Find the best kernel in descending order of preference. | 
					
						
							|  |  |  | static Operation const * find_gemm_operation( | 
					
						
							|  |  |  |   GemmOperationFunctionalMap::const_iterator operators_it,  | 
					
						
							|  |  |  |   GemmPreferenceKey const preference_key) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto cc_it = operators_it->second.upper_bound(preference_key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (cc_it == operators_it->second.begin()) { | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Operation const *operation = nullptr; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Search in descending order of compute capability | 
					
						
							|  |  |  |   do { | 
					
						
							|  |  |  |     --cc_it; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Search tile sizes in order, for now. | 
					
						
							|  |  |  |     for (auto const * op : cc_it->second) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       GemmDescription const &desc = static_cast<GemmDescription const &>(op->description()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int min_cc = desc.tile_description.minimum_compute_capability; | 
					
						
							|  |  |  |       int max_cc = desc.tile_description.maximum_compute_capability; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int op_alignment = maximum_alignment_requirement(desc); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if ((min_cc <= preference_key.compute_capability) && | 
					
						
							|  |  |  |         (preference_key.compute_capability <= max_cc) && | 
					
						
							|  |  |  |         (op_alignment <= preference_key.alignment)) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         operation = op; | 
					
						
							|  |  |  |         break; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } while (!operation && cc_it != operators_it->second.begin()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return operation; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Executes a GEMM computation: D <= alpha * A*B + beta * C | 
					
						
							|  |  |  | Status Handle::gemm( | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int M,                                    /// GEMM M dimension | 
					
						
							|  |  |  |   int N,                                    /// GEMM N dimension | 
					
						
							|  |  |  |   int K,                                    /// GEMM K dimension | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_compute,            /// Data type of internal accumulation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_scalar,             /// Data type of alpha/beta scalars | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const *alpha,                        /// Pointer to alpha scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_A,                  /// Data type of A matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_A,                    /// Layout of A matrix | 
					
						
							|  |  |  |   ComplexTransform transform_A,             /// Complex transformation applied to A matrix - ignored for real-valued matrices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_A,                       /// Pointer to A matrix in Global Memory | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t lda,                              /// Leading dimension of A matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_B,                  /// Data type of B matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_B,                    /// Layout of B matrix | 
					
						
							|  |  |  |   ComplexTransform transform_B,             /// Complex transformation applied to B matrix - ignored for real-valued matrices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_B,                       /// Pointer to B matrix in Global Memory | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldb,                              /// Leading dimension of B matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void const * beta,                        /// Pointer to beta scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_C,                  /// Data type of C and D matrices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_C,                       /// Pointer to C matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldc,                              /// Leading dimension of C matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void * ptr_D,                             /// Pointer to D matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldd                               /// Leading dimension of D matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | ) { | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmFunctionalKey key( | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |     provider_, | 
					
						
							|  |  |  |     GemmKind::kGemm, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     element_compute, | 
					
						
							|  |  |  |     element_scalar, | 
					
						
							|  |  |  |     element_A, | 
					
						
							|  |  |  |     layout_A, | 
					
						
							|  |  |  |     transform_A, | 
					
						
							|  |  |  |     element_B, | 
					
						
							|  |  |  |     layout_B, | 
					
						
							|  |  |  |     transform_B, | 
					
						
							|  |  |  |     element_C | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   if (operators_it->second.empty()) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Compute the largest alignment restriction the kernel can satisfy. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Maximum alignment expectation among all kernels (in units of bytes) | 
					
						
							|  |  |  |   int const kMaximumAlignmentSize = 16; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int alignment = gemm_problem_alignment( | 
					
						
							|  |  |  |     M, N, K,  | 
					
						
							|  |  |  |     element_A, ptr_A, lda, 0, | 
					
						
							|  |  |  |     element_B, ptr_B, ldb, 0, | 
					
						
							|  |  |  |     element_C, ptr_C, ldc, 0, | 
					
						
							|  |  |  |     ptr_D, ldd, 0, kMaximumAlignmentSize | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the best kernel in descending order of preference. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmPreferenceKey preference_key(compute_capability(), alignment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Operation const *operation = find_gemm_operation(operators_it, preference_key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (!operation) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   last_operation_ = operation; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Configure operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmConfiguration configuration{ | 
					
						
							|  |  |  |     {M, N, K}, | 
					
						
							|  |  |  |     lda, | 
					
						
							|  |  |  |     ldb, | 
					
						
							|  |  |  |     ldc, | 
					
						
							|  |  |  |     ldd, | 
					
						
							|  |  |  |     1 | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query host work space size | 
					
						
							|  |  |  |   uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   char host_workspace[kHostWorkspaceSize]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query device workspace size | 
					
						
							|  |  |  |   uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(workspace_size_) < device_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Initialize host and device workspaces | 
					
						
							|  |  |  |   Status status = operation->initialize( | 
					
						
							|  |  |  |     &configuration, | 
					
						
							|  |  |  |     host_workspace, | 
					
						
							|  |  |  |     workspace_, | 
					
						
							|  |  |  |     stream_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |     return status; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Run the operator | 
					
						
							|  |  |  |   GemmArguments arguments{ | 
					
						
							|  |  |  |     ptr_A, | 
					
						
							|  |  |  |     ptr_B, | 
					
						
							|  |  |  |     ptr_C, | 
					
						
							|  |  |  |     ptr_D, | 
					
						
							|  |  |  |     alpha, | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     scalar_pointer_mode_ | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return operation->run(&arguments, host_workspace, workspace_, stream_); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | /// Executes a GEMM computation: D <= alpha * A*B + beta * C. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // Supports batched-strided, batched array or split-K serial or split-K parallel. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | Status Handle::gemm_universal( | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmUniversalMode mode,                   /// indicates the mode in which the kUniversal GEMM is launched | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int M,                                    /// GEMM M dimension | 
					
						
							|  |  |  |   int N,                                    /// GEMM N dimension | 
					
						
							|  |  |  |   int K,                                    /// GEMM K dimension | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_compute,            /// Data type of internal accumulation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_scalar,             /// Data type of alpha/beta scalars | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const *alpha,                        /// Pointer to alpha scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_A,                  /// Data type of A matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_A,                    /// Layout of A matrix | 
					
						
							|  |  |  |   ComplexTransform transform_A,             /// Complex transformation applied to A matrix - ignored for real-valued matrices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_A,                       /// Pointer to A matrix in Global Memory | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t lda,                                  /// Leading dimension of A matrix | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_B,                  /// Data type of B matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_B,                    /// Layout of B matrix | 
					
						
							|  |  |  |   ComplexTransform transform_B,             /// Complex transformation applied to B matrix - ignored for real-valued matrices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_B,                       /// Pointer to B matrix in Global Memory | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldb,                                  /// Leading dimension of B matrix | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void const * beta,                        /// Pointer to beta scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_C,                  /// Data type of C and D matrices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_C,                       /// Pointer to C matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldc,                                  /// Leading dimension of C matrix | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void * ptr_D,                             /// Pointer to D matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldd,                                  /// Leading dimension of D matrix | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   int batch_count,                          /// Batch count or number of split-K slices | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t batch_stride_A,                   /// Batch stride of A operand | 
					
						
							|  |  |  |   int64_t batch_stride_B,                   /// Batch stride of B operand | 
					
						
							|  |  |  |   int64_t batch_stride_C,                   /// Batch stride of C operand | 
					
						
							|  |  |  |   int64_t batch_stride_D                    /// Batch stride of D operand | 
					
						
							|  |  |  | ) { | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmFunctionalKey key( | 
					
						
							|  |  |  |     provider_, | 
					
						
							|  |  |  |     GemmKind::kUniversal, | 
					
						
							|  |  |  |     element_compute, | 
					
						
							|  |  |  |     element_scalar, | 
					
						
							|  |  |  |     element_A, | 
					
						
							|  |  |  |     layout_A, | 
					
						
							|  |  |  |     transform_A, | 
					
						
							|  |  |  |     element_B, | 
					
						
							|  |  |  |     layout_B, | 
					
						
							|  |  |  |     transform_B, | 
					
						
							|  |  |  |     element_C | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   if (operators_it->second.empty()) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Compute the largest alignment restriction the kernel can satisfy. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Maximum alignment expectation among all kernels (in units of bytes) | 
					
						
							|  |  |  |   int const kMaximumAlignmentSize = 16; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const *ptr_A_check = ptr_A; | 
					
						
							|  |  |  |   void const *ptr_B_check = ptr_B; | 
					
						
							|  |  |  |   void const *ptr_C_check = ptr_C; | 
					
						
							|  |  |  |   void *      ptr_D_check = ptr_D; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Ignore alignment of pointers to pointers. We can't check this from the host, | 
					
						
							|  |  |  |   // as each batch index has its own pointer in device memory. | 
					
						
							|  |  |  |   if (mode == GemmUniversalMode::kArray) { | 
					
						
							|  |  |  |     ptr_A_check = nullptr;  | 
					
						
							|  |  |  |     ptr_B_check = nullptr;  | 
					
						
							|  |  |  |     ptr_C_check = nullptr;  | 
					
						
							|  |  |  |     ptr_D_check = nullptr;  | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int alignment = gemm_problem_alignment( | 
					
						
							|  |  |  |     M, N, K,  | 
					
						
							|  |  |  |     element_A, ptr_A_check, lda, 0, | 
					
						
							|  |  |  |     element_B, ptr_B_check, ldb, 0, | 
					
						
							|  |  |  |     element_C, ptr_C_check, ldc, 0, | 
					
						
							|  |  |  |     ptr_D_check, ldd, 0, kMaximumAlignmentSize | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the best kernel in descending order of preference. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmPreferenceKey preference_key(compute_capability(), alignment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Operation const *operation = find_gemm_operation(operators_it, preference_key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (!operation) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   last_operation_ = operation; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Configure operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmUniversalConfiguration configuration{ | 
					
						
							|  |  |  |     mode, | 
					
						
							|  |  |  |     {M, N, K}, | 
					
						
							|  |  |  |     batch_count, | 
					
						
							|  |  |  |     lda, | 
					
						
							|  |  |  |     ldb, | 
					
						
							|  |  |  |     ldc, | 
					
						
							|  |  |  |     ldd | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query host work space size | 
					
						
							|  |  |  |   uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   char host_workspace[kHostWorkspaceSize]; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-03-23 00:36:34 +08:00
										 |  |  |   GemmUniversalArguments arguments{ | 
					
						
							|  |  |  |     ptr_A, | 
					
						
							|  |  |  |     ptr_B, | 
					
						
							|  |  |  |     ptr_C, | 
					
						
							|  |  |  |     ptr_D, | 
					
						
							|  |  |  |     alpha, | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     scalar_pointer_mode_, | 
					
						
							|  |  |  |     batch_stride_A, | 
					
						
							|  |  |  |     batch_stride_B, | 
					
						
							|  |  |  |     batch_stride_C, | 
					
						
							|  |  |  |     batch_stride_D | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   // Query device workspace size | 
					
						
							| 
									
										
										
										
											2022-03-23 00:36:34 +08:00
										 |  |  |   uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration, &arguments); | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(workspace_size_) < device_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Initialize host and device workspaces | 
					
						
							|  |  |  |   Status status = operation->initialize( | 
					
						
							|  |  |  |     &configuration, | 
					
						
							|  |  |  |     host_workspace, | 
					
						
							|  |  |  |     workspace_, | 
					
						
							|  |  |  |     stream_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |     return status; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Run the operator | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   return operation->run(&arguments, host_workspace, workspace_, stream_); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | /// Planar complex GEMM | 
					
						
							|  |  |  | Status Handle::gemm_planar_complex( | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int M,                                    /// GEMM M dimension | 
					
						
							|  |  |  |   int N,                                    /// GEMM N dimension | 
					
						
							|  |  |  |   int K,                                    /// GEMM K dimension | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_compute,            /// Data type of internal accumulation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_scalar,             /// Data type of alpha/beta scalars | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const *alpha,                        /// Pointer to alpha scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_A,                  /// Data type of A matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_A,                    /// Layout of A matrix | 
					
						
							|  |  |  |   ComplexTransform transform_A,             /// Complex transformation applied to A matrix | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_A_real,                  /// Pointer to real part of A matrix | 
					
						
							|  |  |  |   void const * ptr_A_imag,                  /// Pointer to imaginary part of A matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t lda_real,                         /// Leading dimension of real part of A matrix | 
					
						
							|  |  |  |   int64_t lda_imag,                         /// Leading dimension of imaginary part of A matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_B,                  /// Data type of B matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_B,                    /// Layout of B matrix | 
					
						
							|  |  |  |   ComplexTransform transform_B,             /// Complex transformation applied to B matrix | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_B_real,                  /// Pointer to real part of B matrix | 
					
						
							|  |  |  |   void const * ptr_B_imag,                  /// Pointer to imaginary part of B matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldb_real,                             /// Leading dimension of real part of B matrix | 
					
						
							|  |  |  |   int64_t ldb_imag,                             /// Leading dimension of imaginary part of B matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void const * beta,                        /// Pointer to beta scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_C,                  /// Data type of C and D matrix | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * ptr_C_real,                  /// Pointer to real part of C matrix | 
					
						
							|  |  |  |   void const * ptr_C_imag,                  /// Pointer to imaginary part of C matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldc_real,                             /// Leading dimension of real part of C matrix | 
					
						
							|  |  |  |   int64_t ldc_imag,                             /// Leading dimension of imaginary part of C matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void * ptr_D_real,                        /// Pointer to real part of D matrix | 
					
						
							|  |  |  |   void * ptr_D_imag,                        /// Pointer to imaginary part of D matrix | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldd_real,                             /// Leading dimension of real part of D matrix | 
					
						
							|  |  |  |   int64_t ldd_imag,                             /// Leading dimension of imaginary part of D matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   int batch_count,                          /// Number of batched GEMMs to execute | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t batch_stride_A_real, | 
					
						
							|  |  |  |   int64_t batch_stride_A_imag, | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t batch_stride_B_real, | 
					
						
							|  |  |  |   int64_t batch_stride_B_imag, | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t batch_stride_C_real, | 
					
						
							|  |  |  |   int64_t batch_stride_C_imag, | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t batch_stride_D_real, | 
					
						
							|  |  |  |   int64_t batch_stride_D_imag | 
					
						
							|  |  |  | ) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmFunctionalKey key( | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |     provider_, | 
					
						
							|  |  |  |     GemmKind::kPlanarComplex, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     element_compute, | 
					
						
							|  |  |  |     element_scalar, | 
					
						
							|  |  |  |     element_A, | 
					
						
							|  |  |  |     layout_A, | 
					
						
							|  |  |  |     transform_A, | 
					
						
							|  |  |  |     element_B, | 
					
						
							|  |  |  |     layout_B, | 
					
						
							|  |  |  |     transform_B, | 
					
						
							|  |  |  |     element_C | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   if (operators_it->second.empty()) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Compute the largest alignment restriction the kernel can satisfy. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Maximum alignment expectation among all kernels (in units of bytes) | 
					
						
							|  |  |  |   int const kMaximumAlignmentSize = 16; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int alignment = std::max( | 
					
						
							|  |  |  |     gemm_problem_alignment( | 
					
						
							|  |  |  |       M, N, K,  | 
					
						
							|  |  |  |       element_A, ptr_A_real, lda_real, batch_stride_A_real, | 
					
						
							|  |  |  |       element_B, ptr_B_real, ldb_real, batch_stride_B_real, | 
					
						
							|  |  |  |       element_C, ptr_C_real, ldc_real, batch_stride_C_real, | 
					
						
							|  |  |  |       ptr_D_real, ldd_real, batch_stride_D_real, kMaximumAlignmentSize | 
					
						
							|  |  |  |     ), | 
					
						
							|  |  |  |     gemm_problem_alignment( | 
					
						
							|  |  |  |       M, N, K,  | 
					
						
							|  |  |  |       element_A, ptr_A_imag, lda_imag, batch_stride_A_imag, | 
					
						
							|  |  |  |       element_B, ptr_B_imag, ldb_imag, batch_stride_B_imag, | 
					
						
							|  |  |  |       element_C, ptr_C_imag, ldc_imag, batch_stride_C_imag, | 
					
						
							|  |  |  |       ptr_D_imag, ldd_imag, batch_stride_D_imag, kMaximumAlignmentSize | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the best kernel in descending order of preference. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmPreferenceKey preference_key(compute_capability(), alignment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Operation const *operation = find_gemm_operation(operators_it, preference_key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (!operation) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   last_operation_ = operation; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Configure operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmPlanarComplexConfiguration configuration{ | 
					
						
							|  |  |  |     GemmUniversalMode::kBatched, | 
					
						
							|  |  |  |     {M, N, K}, | 
					
						
							|  |  |  |     batch_count, | 
					
						
							|  |  |  |     lda_real, | 
					
						
							|  |  |  |     lda_imag, | 
					
						
							|  |  |  |     ldb_real, | 
					
						
							|  |  |  |     ldb_imag, | 
					
						
							|  |  |  |     ldc_real, | 
					
						
							|  |  |  |     ldc_imag, | 
					
						
							|  |  |  |     ldd_real, | 
					
						
							|  |  |  |     ldd_imag | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query host work space size | 
					
						
							|  |  |  |   uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   char host_workspace[kHostWorkspaceSize]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query device workspace size | 
					
						
							|  |  |  |   uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(workspace_size_) < device_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Initialize host and device workspaces | 
					
						
							|  |  |  |   Status status = operation->initialize( | 
					
						
							|  |  |  |     &configuration, | 
					
						
							|  |  |  |     host_workspace, | 
					
						
							|  |  |  |     workspace_, | 
					
						
							|  |  |  |     stream_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |     return status; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Run the operator | 
					
						
							|  |  |  |   GemmPlanarComplexArguments arguments{ | 
					
						
							|  |  |  |     ptr_A_real, | 
					
						
							|  |  |  |     ptr_A_imag, | 
					
						
							|  |  |  |     ptr_B_real, | 
					
						
							|  |  |  |     ptr_B_imag, | 
					
						
							|  |  |  |     ptr_C_real, | 
					
						
							|  |  |  |     ptr_C_imag, | 
					
						
							|  |  |  |     ptr_D_real, | 
					
						
							|  |  |  |     ptr_D_imag, | 
					
						
							|  |  |  |     alpha, | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     scalar_pointer_mode_, | 
					
						
							|  |  |  |     batch_stride_A_real, | 
					
						
							|  |  |  |     batch_stride_A_imag, | 
					
						
							|  |  |  |     batch_stride_B_real, | 
					
						
							|  |  |  |     batch_stride_B_imag, | 
					
						
							|  |  |  |     batch_stride_C_real, | 
					
						
							|  |  |  |     batch_stride_C_imag, | 
					
						
							|  |  |  |     batch_stride_D_real, | 
					
						
							|  |  |  |     batch_stride_D_imag | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return operation->run(&arguments, host_workspace, workspace_, stream_); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Planar complex batched GEMM loading pointers from arrays in global memory | 
					
						
							|  |  |  | Status Handle::gemm_planar_complex_array( | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int expected_M,                           /// Expected GEMM M dimension (used for sizing CUDA grid) | 
					
						
							|  |  |  |   int expected_N,                           /// Expected GEMM N dimension (used for sizing CUDA grid) | 
					
						
							|  |  |  |   int expected_K,                           /// Expected GEMM K dimension | 
					
						
							|  |  |  |   int batch_count,                          /// Number of independent GEMM computations to execute | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int const *M,                             /// Array containing the GEMM M dimension for each batch index | 
					
						
							|  |  |  |   int const *N,                             /// Array containing the GEMM N dimension for each batch index | 
					
						
							|  |  |  |   int const *K,                             /// Array containing the GEMM K dimension for each batch index | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_compute,            /// Data type of internal accumulation | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_scalar,             /// Data type of alpha/beta scalars | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const *alpha,                        /// Pointer to alpha scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_A,                  /// Data type of A matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_A,                    /// Layout of A matrix | 
					
						
							|  |  |  |   ComplexTransform transform_A,             /// Complex transformation applied to A matrix | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * const * ptr_A_real,          /// Pointer to array containing pointers to real part of A matrices | 
					
						
							|  |  |  |   void const * const * ptr_A_imag,          /// Pointer to array containing pointers to imaginary part of A matrices | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t lda_real,                             /// Leading dimension of real part of A matrix | 
					
						
							|  |  |  |   int64_t lda_imag,                             /// Leading dimension of imaginary part of A matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_B,                  /// Data type of B matrix elements | 
					
						
							|  |  |  |   LayoutTypeID layout_B,                    /// Layout of B matrix | 
					
						
							|  |  |  |   ComplexTransform transform_B,             /// Complex transformation applied to B matrix | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * const * ptr_B_real,          /// Pointer to array containing pointers to real part of B matrices | 
					
						
							|  |  |  |   void const * const * ptr_B_imag,          /// Pointer to array containing pointers to imaginary part of B matrices | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldb_real,                             /// Leading dimension of real part of B matrix | 
					
						
							|  |  |  |   int64_t ldb_imag,                             /// Leading dimension of imaginary part of B matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void const * beta,                        /// Pointer to beta scalar | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   NumericTypeID element_C,                  /// Data type of C and D matrix | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   void const * const * ptr_C_real,          /// Pointer to array containing pointers to real part of C matrices | 
					
						
							|  |  |  |   void const * const * ptr_C_imag,          /// Pointer to array containing poitners to imaginary part of C matrices | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldc_real,                             /// Leading dimension of real part of C matrix | 
					
						
							|  |  |  |   int64_t ldc_imag,                             /// Leading dimension of imaginary part of C matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   void * const * ptr_D_real,                /// Pointer to array containing pointers to real part of D matrices | 
					
						
							|  |  |  |   void * const * ptr_D_imag,                /// Pointer to array containing poitners to imaginary part of D matrices | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |   int64_t ldd_real,                             /// Leading dimension of real part of D matrix | 
					
						
							|  |  |  |   int64_t ldd_imag                              /// Leading dimension of imaginary part of D matrix | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | ) { | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmFunctionalKey key( | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |     provider_, | 
					
						
							|  |  |  |     GemmKind::kPlanarComplexArray, | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     element_compute, | 
					
						
							|  |  |  |     element_scalar, | 
					
						
							|  |  |  |     element_A, | 
					
						
							|  |  |  |     layout_A, | 
					
						
							|  |  |  |     transform_A, | 
					
						
							|  |  |  |     element_B, | 
					
						
							|  |  |  |     layout_B, | 
					
						
							|  |  |  |     transform_B, | 
					
						
							|  |  |  |     element_C | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   auto operators_it = Singleton::get().operation_table.gemm_operations.find(key); | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-06-09 07:17:35 +08:00
										 |  |  |   if (operators_it == Singleton::get().operation_table.gemm_operations.end()) { | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   if (operators_it->second.empty()) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Compute the largest alignment restriction the kernel can satisfy. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Maximum alignment expectation among all kernels (in units of bytes) | 
					
						
							|  |  |  |   int const kMaximumAlignmentSize = 16; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int alignment = std::max( | 
					
						
							|  |  |  |     gemm_problem_alignment( | 
					
						
							|  |  |  |       expected_M, expected_N, expected_K,  | 
					
						
							|  |  |  |       element_A, nullptr, lda_real, 0, | 
					
						
							|  |  |  |       element_B, nullptr, ldb_real, 0, | 
					
						
							|  |  |  |       element_C, nullptr, ldc_real, 0, | 
					
						
							|  |  |  |       nullptr, ldd_real, 0, kMaximumAlignmentSize | 
					
						
							|  |  |  |     ), | 
					
						
							|  |  |  |     gemm_problem_alignment( | 
					
						
							|  |  |  |       expected_M, expected_N, expected_K,  | 
					
						
							|  |  |  |       element_A, nullptr, lda_imag, 0, | 
					
						
							|  |  |  |       element_B, nullptr, ldb_imag, 0, | 
					
						
							|  |  |  |       element_C, nullptr, ldc_imag, 0, | 
					
						
							|  |  |  |       nullptr, ldd_imag, 0, kMaximumAlignmentSize | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Find the best kernel in descending order of preference. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmPreferenceKey preference_key(compute_capability(), alignment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   Operation const *operation = find_gemm_operation(operators_it, preference_key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (!operation) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   last_operation_ = operation; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Configure operation | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmPlanarComplexArrayConfiguration configuration{ | 
					
						
							|  |  |  |     {expected_M, expected_N, expected_K}, | 
					
						
							|  |  |  |     batch_count, | 
					
						
							|  |  |  |     lda_real, | 
					
						
							|  |  |  |     lda_imag, | 
					
						
							|  |  |  |     ldb_real, | 
					
						
							|  |  |  |     ldb_imag, | 
					
						
							|  |  |  |     ldc_real, | 
					
						
							|  |  |  |     ldc_imag, | 
					
						
							|  |  |  |     ldd_real, | 
					
						
							|  |  |  |     ldd_imag | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query host work space size | 
					
						
							|  |  |  |   uint64_t host_workspace_size_needed = operation->get_host_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(kHostWorkspaceSize) < host_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   char host_workspace[kHostWorkspaceSize]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Query device workspace size | 
					
						
							|  |  |  |   uint64_t device_workspace_size_needed = operation->get_device_workspace_size(&configuration); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (uint64_t(workspace_size_) < device_workspace_size_needed) { | 
					
						
							|  |  |  |     return cutlass::Status::kErrorNotSupported; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Initialize host and device workspaces | 
					
						
							|  |  |  |   Status status = operation->initialize( | 
					
						
							|  |  |  |     &configuration, | 
					
						
							|  |  |  |     host_workspace, | 
					
						
							|  |  |  |     workspace_, | 
					
						
							|  |  |  |     stream_); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |     return status; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Run the operator | 
					
						
							|  |  |  |   GemmPlanarComplexArrayArguments arguments{ | 
					
						
							|  |  |  |     M, N, K, | 
					
						
							|  |  |  |     ptr_A_real, | 
					
						
							|  |  |  |     ptr_A_imag, | 
					
						
							|  |  |  |     ptr_B_real, | 
					
						
							|  |  |  |     ptr_B_imag, | 
					
						
							|  |  |  |     ptr_C_real, | 
					
						
							|  |  |  |     ptr_C_imag, | 
					
						
							|  |  |  |     ptr_D_real, | 
					
						
							|  |  |  |     ptr_D_imag, | 
					
						
							|  |  |  |     alpha, | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     scalar_pointer_mode_ | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return operation->run(&arguments, host_workspace, workspace_, stream_); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | /// Finds conv operation instances with Conv::ElementC = Reduction::ElementWorkspace | 
					
						
							|  |  |  | Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   ConvDescription const &conv_desc =  | 
					
						
							|  |  |  |     static_cast<ConvDescription const &>(operation->description()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // if the curren conv operation accumulator and output data type match return operation | 
					
						
							|  |  |  |   if(conv_desc.tile_description.math_instruction.element_accumulator == conv_desc.C.element) { | 
					
						
							|  |  |  |     return operation; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // find conv operation to match conv output and reduction workspace data type | 
					
						
							|  |  |  |   ConvFunctionalKey key( | 
					
						
							|  |  |  |     library::Provider::kCUTLASS, | 
					
						
							|  |  |  |     conv_desc.conv_kind,         | 
					
						
							|  |  |  |     conv_desc.A.element, | 
					
						
							|  |  |  |     conv_desc.A.layout, | 
					
						
							|  |  |  |     conv_desc.B.element, | 
					
						
							|  |  |  |     conv_desc.B.layout, | 
					
						
							|  |  |  |     conv_desc.tile_description.math_instruction.element_accumulator, | 
					
						
							|  |  |  |     conv_desc.C.layout, | 
					
						
							|  |  |  |     conv_desc.tile_description.math_instruction.element_accumulator,  | 
					
						
							|  |  |  |     conv_desc.element_epilogue); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // conv operation table for conv2d or conv3d | 
					
						
							|  |  |  |   auto conv_operations = (conv_desc.kind == OperationKind::kConv2d) ?  | 
					
						
							|  |  |  |                           Singleton::get().operation_table.conv2d_operations :  | 
					
						
							|  |  |  |                           Singleton::get().operation_table.conv3d_operations; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // find ConvFunctionalKey in convolution operation table | 
					
						
							|  |  |  |   auto operators_it = conv_operations.find(key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (operators_it == conv_operations.end()) { | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   if (operators_it->second.empty()) { | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // conv operation for same compute capability and iterator algorithm | 
					
						
							|  |  |  |   ConvPreferenceKey preference_key( | 
					
						
							|  |  |  |     conv_desc.tile_description.minimum_compute_capability,  | 
					
						
							|  |  |  |     conv_desc.iterator_algorithm); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   auto it = operators_it->second.find(preference_key); | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   if(it == operators_it->second.end()) { | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // return matching conv opertion (same tile sizes and instruction) | 
					
						
							|  |  |  |   for (auto op : it->second) { | 
					
						
							|  |  |  |     if (op->description().tile_description == operation->description().tile_description) { | 
					
						
							|  |  |  |       return op; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return nullptr; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2022-01-27 23:37:37 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace | 
					
						
							|  |  |  | Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmDescription const &gemm_desc =  | 
					
						
							|  |  |  |     static_cast<GemmDescription const &>(operation->description()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // if the curren gemm operation accumulator and output data type match return operation | 
					
						
							|  |  |  |   if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) { | 
					
						
							|  |  |  |     return operation; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // find gemm operation to match gemm output and reduction workspace data type | 
					
						
							|  |  |  |   GemmFunctionalKey key( | 
					
						
							|  |  |  |     library::Provider::kCUTLASS, | 
					
						
							|  |  |  |     gemm_desc.gemm_kind, | 
					
						
							|  |  |  |     gemm_desc.tile_description.math_instruction.element_accumulator, | 
					
						
							|  |  |  |     gemm_desc.element_epilogue, | 
					
						
							|  |  |  |     gemm_desc.A.element, | 
					
						
							|  |  |  |     gemm_desc.A.layout, | 
					
						
							|  |  |  |     gemm_desc.transform_A, | 
					
						
							|  |  |  |     gemm_desc.B.element, | 
					
						
							|  |  |  |     gemm_desc.B.layout, | 
					
						
							|  |  |  |     gemm_desc.transform_B, | 
					
						
							|  |  |  |     gemm_desc.tile_description.math_instruction.element_accumulator); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // gemm operation table | 
					
						
							|  |  |  |   auto gemm_operations = Singleton::get().operation_table.gemm_operations; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // find ConvFunctionalKey in gemm operation table | 
					
						
							|  |  |  |   auto operators_it = gemm_operations.find(key); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (operators_it == gemm_operations.end()) { | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   if (operators_it->second.empty()) { | 
					
						
							|  |  |  |     return nullptr; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A and B uses the same alignment in the generator.py | 
					
						
							|  |  |  |   int alignment = gemm_desc.A.alignment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // gemm operation for same compute capability and iterator algorithm | 
					
						
							|  |  |  |   GemmPreferenceKey preference_key( | 
					
						
							|  |  |  |     gemm_desc.tile_description.minimum_compute_capability,  | 
					
						
							|  |  |  |     alignment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return find_gemm_operation(operators_it, preference_key); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-11-20 13:25:25 +08:00
										 |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-04-08 04:51:25 +08:00
										 |  |  | } // namespace library | 
					
						
							|  |  |  | } // namespace cutlass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// |