| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | /*************************************************************************************************** | 
					
						
							| 
									
										
										
										
											2023-01-21 05:32:57 +08:00
										 |  |  |  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file | 
					
						
							|  |  |  |     \brief Tests for device-wide GEMM interface | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <iostream> | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "../../common/cutlass_unit_test.h" | 
					
						
							|  |  |  | #include "cutlass/cutlass.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/gemm/gemm.h" | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/gemm_grouped.h" | 
					
						
							|  |  |  | #include "cutlass/gemm/kernel/default_gemm_grouped.h" | 
					
						
							|  |  |  | #include "cutlass/gemm/device/gemm_grouped.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/util/host_tensor.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/gemm.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/tensor_compare.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/tensor_copy.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/tensor_fill.h" | 
					
						
							|  |  |  | #include "cutlass/util/tensor_view_io.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "testbed_grouped.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Visitor class to abstract away the algorithm for iterating over tiles. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This is the prototype. We will delete this when the efficient kernel is | 
					
						
							|  |  |  | // available. | 
					
						
							|  |  |  | struct GemmGroupedProblemVisitor { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct Params { | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord const *problem_sizes; | 
					
						
							|  |  |  |     int32_t                         problem_count; | 
					
						
							|  |  |  |     int64_t const                  *tile_count; | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   struct SharedStorage { | 
					
						
							|  |  |  |     // | 
					
						
							|  |  |  |     // Nothing for now. As an optimization step, we could consider parallel | 
					
						
							|  |  |  |     // argmin or prefix sums across the block. | 
					
						
							|  |  |  |     // | 
					
						
							|  |  |  |   }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Data members | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   SharedStorage &shared_storage; | 
					
						
							|  |  |  |   Params const ¶ms; | 
					
						
							|  |  |  |   cutlass::MatrixCoord threadblock_shape; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int64_t tile_idx; | 
					
						
							|  |  |  |   int64_t tile_count_sum; | 
					
						
							|  |  |  |   int64_t problem_tile_start; | 
					
						
							|  |  |  |   int32_t problem_idx; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Methods | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   GemmGroupedProblemVisitor( | 
					
						
							|  |  |  |     SharedStorage &shared_storage_,  | 
					
						
							|  |  |  |     Params const ¶ms_, | 
					
						
							|  |  |  |     cutlass::MatrixCoord threadblock_shape_, | 
					
						
							|  |  |  |     int32_t block_idx | 
					
						
							|  |  |  |   ): | 
					
						
							|  |  |  |     shared_storage(shared_storage_), | 
					
						
							|  |  |  |     params(params_), | 
					
						
							|  |  |  |     threadblock_shape(threadblock_shape_), | 
					
						
							|  |  |  |     tile_idx(block_idx), | 
					
						
							|  |  |  |     tile_count_sum(0), | 
					
						
							|  |  |  |     problem_idx(0) | 
					
						
							|  |  |  |   { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord  grid = grid_shape(problem); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     problem_tile_start = 0; | 
					
						
							|  |  |  |     tile_count_sum = grid.m() * grid.n(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Get the grid shape | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   static cutlass::gemm::GemmCoord grid_shape( | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord const &problem, | 
					
						
							|  |  |  |     cutlass::MatrixCoord const & block_shape) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return cutlass::gemm::GemmCoord( | 
					
						
							|  |  |  |       ((problem.m() - 1 + block_shape.row()) / block_shape.row()), | 
					
						
							|  |  |  |       ((problem.n() - 1 + block_shape.column()) / block_shape.column()), | 
					
						
							|  |  |  |       1); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Get the grid shape | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const &problem) const { | 
					
						
							|  |  |  |     return grid_shape(problem, threadblock_shape); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Returns true if there is a tile to compute | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   bool next_tile() { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (tile_idx < tile_count_sum) { | 
					
						
							|  |  |  |       return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     do { | 
					
						
							|  |  |  |       ++problem_idx; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (problem_idx >= params.problem_count) { | 
					
						
							|  |  |  |         return false; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem = params.problem_sizes[problem_idx]; | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord  grid = grid_shape(problem); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int64_t tile_count = grid.m() * grid.n(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       problem_tile_start = tile_count_sum; | 
					
						
							|  |  |  |       tile_count_sum += tile_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     } while (tile_count_sum <= tile_idx); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Gets the global tile index | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   int64_t tile_index() const { | 
					
						
							|  |  |  |     return tile_idx; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Gets the index of the problem | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   int32_t problem_index() const { | 
					
						
							|  |  |  |     return problem_idx; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Returns the problem size for the current problem | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   cutlass::gemm::GemmCoord problem_size() const { | 
					
						
							|  |  |  |     return params.problem_sizes[problem_idx]; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |   int64_t threadblock_idx() const { | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     return tile_idx - problem_tile_start; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   void advance(int32_t grid_size) { | 
					
						
							|  |  |  |     tile_idx += grid_size;  | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  | template <int ThreadblockShapeM, int ThreadblockShapeN> | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | __global__ void GroupedBatchedKernel(GemmGroupedProblemVisitor::Params params) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   __shared__ GemmGroupedProblemVisitor::SharedStorage shared_storage; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmGroupedProblemVisitor problem_visitor( | 
					
						
							|  |  |  |     shared_storage,  | 
					
						
							|  |  |  |     params,  | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     {ThreadblockShapeM, ThreadblockShapeN},  | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     blockIdx.x); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   while (problem_visitor.next_tile()) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     int64_t threadblock_idx                       = problem_visitor.threadblock_idx(); | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     int threadblock_tile_m_idx = int(threadblock_idx / grid_shape.n()); | 
					
						
							|  |  |  |     int threadblock_tile_n_idx = int(threadblock_idx % grid_shape.n()); | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     // | 
					
						
							|  |  |  |     // Do the MMA | 
					
						
							|  |  |  |     // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (threadIdx.x == 0) { | 
					
						
							|  |  |  |       #if 0 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |       printf("Block %d - tile: %lld, problem %d, threadblock_idx: %lld, threadblock(m: %d, n: %d)\n",  | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |         blockIdx.x,  | 
					
						
							|  |  |  |         problem_visitor.tile_index(),  | 
					
						
							|  |  |  |         problem_visitor.problem_index(),  | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |         threadblock_idx,  | 
					
						
							|  |  |  |         threadblock_tile_m_idx,  | 
					
						
							|  |  |  |         threadblock_tile_n_idx); | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |       #endif | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Next tile | 
					
						
							|  |  |  |     problem_visitor.advance(gridDim.x); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST(SM80_Device_GemmGrouped_scheduler, 64x64x32_32x32x32) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int32_t problem_count = 16; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |   int const kThreadblockShapeM = 64; | 
					
						
							|  |  |  |   int const kThreadblockShapeN = 64; | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   std::vector<cutlass::gemm::GemmCoord> problem_sizes(problem_count); | 
					
						
							|  |  |  |   std::vector<int64_t> tile_counts(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // construct a few problems of random sizes | 
					
						
							|  |  |  |   srand(1921); | 
					
						
							|  |  |  |   for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  |     problem_sizes.at(i) = cutlass::gemm::GemmCoord( | 
					
						
							|  |  |  |       8 * (rand() % 48) + 64, | 
					
						
							|  |  |  |       8 * (rand() % 48) + 64, | 
					
						
							|  |  |  |       8 * (rand() % 48) + 64); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // compute prefix sum | 
					
						
							|  |  |  |   int64_t tile_count = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::gemm::GemmCoord grid_shape = GemmGroupedProblemVisitor::grid_shape( | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |       problem_sizes.at(i), {kThreadblockShapeM, kThreadblockShapeN}); | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     int32_t problem_tile_count = (grid_shape.m() * grid_shape.n()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t tile_start = tile_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     tile_count += problem_tile_count; | 
					
						
							|  |  |  |     tile_counts.at(i) = tile_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (false) { | 
					
						
							|  |  |  |       std::cout << "Problem " << i << " size("  | 
					
						
							|  |  |  |         << problem_sizes.at(i).m() << "-by-" << problem_sizes.at(i).n()  | 
					
						
							|  |  |  |         << ") - tiles: " << problem_tile_count << ",  grid(" << grid_shape.m() << ", " << grid_shape.n()  | 
					
						
							|  |  |  |         << "), tiles[" << tile_start << ", " << tile_count << ")" << std::endl;   | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Copy to device memory | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device(problem_count); | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int64_t>                  tile_counts_device(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   problem_sizes_device.copy_from_host(problem_sizes.data()); | 
					
						
							|  |  |  |   tile_counts_device.copy_from_host(tile_counts.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   GemmGroupedProblemVisitor::Params params; | 
					
						
							|  |  |  |   params.problem_sizes = problem_sizes_device.get(); | 
					
						
							|  |  |  |   params.problem_count = problem_count; | 
					
						
							|  |  |  |   params.tile_count = tile_counts_device.get(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Launch the kernel | 
					
						
							|  |  |  |   dim3 grid(108, 1, 1); | 
					
						
							|  |  |  |   dim3 block(128, 1, 1); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |   GroupedBatchedKernel<kThreadblockShapeM, kThreadblockShapeN><<< grid, block >>>(params); | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // wait | 
					
						
							|  |  |  |   cudaDeviceSynchronize(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST(SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32, 128x128x32_64x64x32) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementOutput = float; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     cutlass::half_t,  | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     cutlass::half_t, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 128, 32>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 64, 32>,  | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 16>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     3>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(24); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_f16n_f16t_f32t_tensor_op_f32, 128x128x32_64x64x32) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementOutput = float; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     cutlass::half_t, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     cutlass::half_t, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::RowMajor,    // row major | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp, | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 128, 32>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 64, 32>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 16>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, | 
					
						
							|  |  |  |     3>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(24); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_f16t_f16n_f32n_tensor_op_f32, 128x64x32_64x32x32) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementOutput = cutlass::half_t; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     cutlass::half_t,  | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     cutlass::half_t, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 64, 32>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 32, 32>,  | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 16>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     4>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_f16t_f16n_f32t_tensor_op_f32, 128x64x32_64x32x32) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementOutput = cutlass::half_t; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     cutlass::half_t, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     cutlass::half_t, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     8, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp, | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 64, 32>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 32, 32>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 16>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, | 
					
						
							|  |  |  |     4>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_f64t_f64t_f64n_tensor_op_f64, 64x64x16_32x32x16) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = double; | 
					
						
							|  |  |  |   using ElementOutput = double; | 
					
						
							|  |  |  |   using ElementAccumulator = double; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput,  | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 64, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<32, 32, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<8, 8, 4>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     4>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x128x8_64x32x1) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = float; | 
					
						
							|  |  |  |   using ElementOutput = float; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput,  | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassSimt,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 128, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 32, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<1, 1, 1>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     3>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x128x8_64x32x1) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = float; | 
					
						
							|  |  |  |   using ElementOutput = float; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     cutlass::arch::OpClassSimt, | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 128, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 32, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<1, 1, 1>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, | 
					
						
							|  |  |  |     3>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST(SM80_Device_GemmGrouped_f32t_f32t_f32n_simt_f32, 128x64x8_64x32x1) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = float; | 
					
						
							|  |  |  |   using ElementOutput = float; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     cutlass::arch::OpClassSimt, | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 64, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 32, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<1, 1, 1>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, | 
					
						
							|  |  |  |     3>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | TEST(SM80_Device_GemmGrouped_f32t_f32t_f32t_simt_f32, 128x64x8_64x32x1) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = float; | 
					
						
							|  |  |  |   using ElementOutput = float; | 
					
						
							|  |  |  |   using ElementAccumulator = float; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     cutlass::arch::OpClassSimt, | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<128, 64, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 32, 8>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<1, 1, 1>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, | 
					
						
							|  |  |  |     3>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_cf32n_cf32n_cf32n_tensorop_f32, 64x64x16_32x32x16) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = cutlass::complex<float>; | 
					
						
							|  |  |  |   using ElementOutput = cutlass::complex<float>; | 
					
						
							|  |  |  |   using ElementAccumulator = cutlass::complex<float>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput,  | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 64, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<32, 32, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 8>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     3, | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     cutlass::arch::OpMultiplyAddComplex>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32n_tensorop_f32, 64x64x16_32x32x16) { | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = cutlass::complex<float>; | 
					
						
							|  |  |  |   using ElementOutput = cutlass::complex<float>; | 
					
						
							|  |  |  |   using ElementAccumulator = cutlass::complex<float>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput,  | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kConjugate, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kConjugate, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 64, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<32, 32, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 8>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     3, | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     cutlass::arch::OpMultiplyAddComplex>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_cf32c_cf32t_cf32t_tensorop_f32, 64x64x16_32x32x16) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = cutlass::complex<float>; | 
					
						
							|  |  |  |   using ElementOutput = cutlass::complex<float>; | 
					
						
							|  |  |  |   using ElementAccumulator = cutlass::complex<float>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kConjugate, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kConjugate, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementAccumulator, | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp, | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<64, 64, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<32, 32, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 8, 8>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, | 
					
						
							|  |  |  |     3, | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |     cutlass::arch::OpMultiplyAddComplex>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | TEST(SM80_Device_GemmGrouped_cf32t_cf32h_cf32n_tensorop_f32, 64x64x16_16x16x16) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementInput = cutlass::complex<double>; | 
					
						
							|  |  |  |   using ElementOutput = cutlass::complex<double>; | 
					
						
							|  |  |  |   using ElementAccumulator = cutlass::complex<double>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< | 
					
						
							|  |  |  |     ElementInput,  | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kNone, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementInput, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |     cutlass::ComplexTransform::kConjugate, | 
					
						
							|  |  |  |     1, | 
					
						
							|  |  |  |     ElementOutput, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementAccumulator,  | 
					
						
							|  |  |  |     cutlass::arch::OpClassTensorOp,  | 
					
						
							|  |  |  |     cutlass::arch::Sm80, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<32, 32, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 16, 16>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<8, 8, 4>, | 
					
						
							|  |  |  |     cutlass::epilogue::thread::LinearCombination< | 
					
						
							|  |  |  |         ElementOutput, 1, | 
					
						
							|  |  |  |         ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,  | 
					
						
							|  |  |  |     3, | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     cutlass::arch::OpMultiplyAddComplex>::GemmKernel; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Gemm = cutlass::gemm::device::GemmGrouped<GemmKernel>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Test | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   test::gemm::device::TestbedGrouped<Gemm> testbed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   bool passed = testbed.run(27); | 
					
						
							|  |  |  |   EXPECT_TRUE(passed); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////////////////////// |