/***************************************************************************************************
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
					
						
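/*! \file
    \brief Testbed for the device-wide grouped GEMM interface.

    Runs a group of randomly sized GEMM problems through a single grouped GEMM launch and
    checks every problem's result against a host reference GEMM.
*/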

#pragma once

#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <vector>

#include "../../common/cutlass_unit_test.h"
#include "cutlass/cutlass.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_grouped.h"
#include "cutlass/gemm/kernel/default_gemm_grouped.h"
#include "cutlass/gemm/device/gemm_grouped.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/gemm_complex.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/tensor_view_io.h"

							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace test { | 
					
						
							|  |  |  | namespace gemm { | 
					
						
							|  |  |  | namespace device { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Gemm> | 
					
						
							|  |  |  | struct TestbedGrouped { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Type definitions
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using ElementA = typename Gemm::ElementA; | 
					
						
							|  |  |  |   using ElementB = typename Gemm::ElementB; | 
					
						
							|  |  |  |   using ElementC = typename Gemm::ElementC; | 
					
						
							|  |  |  |   using ElementAccumulator = typename Gemm::ElementAccumulator; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; | 
					
						
							|  |  |  |   using ElementCompute = typename EpilogueOutputOp::ElementCompute; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using LayoutA = typename Gemm::LayoutA; | 
					
						
							|  |  |  |   using LayoutB = typename Gemm::LayoutB; | 
					
						
							|  |  |  |   using LayoutC = typename Gemm::LayoutC; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using MatrixCoord = typename LayoutC::TensorCoord; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Data members
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Initialization
 | 
					
						
							|  |  |  |   cutlass::Distribution::Kind init_A; | 
					
						
							|  |  |  |   cutlass::Distribution::Kind init_B; | 
					
						
							|  |  |  |   cutlass::Distribution::Kind init_C; | 
					
						
							|  |  |  |   uint32_t seed; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   int problem_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   std::vector<cutlass::gemm::GemmCoord>               problem_sizes_host; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<cutlass::gemm::GemmCoord> problem_sizes_device; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   std::vector<int64_t> offset_A; | 
					
						
							|  |  |  |   std::vector<int64_t> offset_B; | 
					
						
							|  |  |  |   std::vector<int64_t> offset_C; | 
					
						
							|  |  |  |   std::vector<int64_t> offset_D; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   std::vector<int64_t> lda_host; | 
					
						
							|  |  |  |   std::vector<int64_t> ldb_host; | 
					
						
							|  |  |  |   std::vector<int64_t> ldc_host; | 
					
						
							|  |  |  |   std::vector<int64_t> ldd_host; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int64_t> lda; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int64_t> ldb; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int64_t> ldc; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<int64_t> ldd; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementA> block_A; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementB> block_B; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementC> block_C; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementC> block_D; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementA *> ptr_A; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementB *> ptr_B; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementC *> ptr_C; | 
					
						
							|  |  |  |   cutlass::DeviceAllocation<ElementC *> ptr_D; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Methods
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   TestbedGrouped( | 
					
						
							|  |  |  |     cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, | 
					
						
							|  |  |  |     cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, | 
					
						
							|  |  |  |     cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, | 
					
						
							|  |  |  |     uint32_t seed_ = 3080 | 
					
						
							|  |  |  |   ): | 
					
						
							|  |  |  |     init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Helper to initialize a tensor view
 | 
					
						
							|  |  |  |   template <typename Element, typename Layout> | 
					
						
							|  |  |  |   bool initialize_tensor( | 
					
						
							|  |  |  |     cutlass::TensorView<Element, Layout> view,  | 
					
						
							|  |  |  |     cutlass::Distribution::Kind dist_kind, | 
					
						
							|  |  |  |     uint32_t seed) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (dist_kind == cutlass::Distribution::Uniform) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       double scope_max, scope_min; | 
					
						
							|  |  |  |       int bits_input = cutlass::sizeof_bits<Element>::value; | 
					
						
							|  |  |  |       int bits_output = cutlass::sizeof_bits<typename Gemm::ElementC>::value; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (bits_input == 1) { | 
					
						
							|  |  |  |         scope_max = 2; | 
					
						
							|  |  |  |         scope_min = 0; | 
					
						
							|  |  |  |       } else if (bits_input <= 8) { | 
					
						
							|  |  |  |         scope_max = 2; | 
					
						
							|  |  |  |         scope_min = -2; | 
					
						
							|  |  |  |       } else if (bits_output == 16) { | 
					
						
							|  |  |  |         if (cutlass::sizeof_bits<ElementAccumulator>::value <= 16) { | 
					
						
							|  |  |  |           scope_max = 5; | 
					
						
							|  |  |  |           scope_min = -5; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         else { | 
					
						
							|  |  |  |           scope_max = 8; | 
					
						
							|  |  |  |           scope_min = -8; | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |       } else { | 
					
						
							|  |  |  |         scope_max = 8; | 
					
						
							|  |  |  |         scope_min = -8; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::reference::host::TensorFillRandomUniform( | 
					
						
							|  |  |  |         view, seed, scope_max, scope_min, 0); | 
					
						
							|  |  |  |     }  | 
					
						
							|  |  |  |     else if (dist_kind == cutlass::Distribution::Identity) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::reference::host::TensorFillIdentity(view); | 
					
						
							|  |  |  |     }  | 
					
						
							|  |  |  |     else if (dist_kind == cutlass::Distribution::Gaussian) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     else if (dist_kind == cutlass::Distribution::Sequential) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::reference::host::BlockFillSequential( | 
					
						
							|  |  |  |         view.data(), view.capacity()); | 
					
						
							|  |  |  |     }  | 
					
						
							|  |  |  |     else { | 
					
						
							|  |  |  |       // no fill - remain zero
 | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return true; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Initializes data structures
 | 
					
						
							|  |  |  |   void initialize() { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Choose random problem sizes
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // construct a few problems of random sizes
 | 
					
						
							|  |  |  |     srand(seed); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     int64_t total_elements_A = 0; | 
					
						
							|  |  |  |     int64_t total_elements_B = 0; | 
					
						
							|  |  |  |     int64_t total_elements_C = 0; | 
					
						
							|  |  |  |     int64_t total_elements_D = 0; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     lda_host.resize(problem_count); | 
					
						
							|  |  |  |     ldb_host.resize(problem_count); | 
					
						
							|  |  |  |     ldc_host.resize(problem_count); | 
					
						
							|  |  |  |     ldd_host.resize(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     problem_sizes_host.clear(); | 
					
						
							|  |  |  |     problem_sizes_host.resize(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem( | 
					
						
							|  |  |  |         8 * (rand() % 64) + 24, | 
					
						
							|  |  |  |         8 * (rand() % 64) + 24, | 
					
						
							|  |  |  |         8 * (rand() % 64) + 24); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (!i) { | 
					
						
							|  |  |  |         problem = cutlass::gemm::GemmCoord(48, 16, 8); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       problem_sizes_host.at(i) = problem; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // std::cout << "Problem[" << i << "]: " << problem << std::endl;
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); | 
					
						
							|  |  |  |       ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); | 
					
						
							|  |  |  |       ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); | 
					
						
							|  |  |  |       ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       offset_A.push_back(total_elements_A); | 
					
						
							|  |  |  |       offset_B.push_back(total_elements_B); | 
					
						
							|  |  |  |       offset_C.push_back(total_elements_C); | 
					
						
							|  |  |  |       offset_D.push_back(total_elements_D); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       int64_t elements_A = problem.m() * problem.k(); | 
					
						
							|  |  |  |       int64_t elements_B = problem.k() * problem.n(); | 
					
						
							|  |  |  |       int64_t elements_C = problem.m() * problem.n(); | 
					
						
							|  |  |  |       int64_t elements_D = problem.m() * problem.n(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       total_elements_A += elements_A; | 
					
						
							|  |  |  |       total_elements_B += elements_B; | 
					
						
							|  |  |  |       total_elements_C += elements_C; | 
					
						
							|  |  |  |       total_elements_D += elements_D; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Random strides between problems?
 | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     problem_sizes_device.reset(problem_count); | 
					
						
							|  |  |  |     problem_sizes_device.copy_from_host(problem_sizes_host.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     lda.reset(problem_count); | 
					
						
							|  |  |  |     ldb.reset(problem_count); | 
					
						
							|  |  |  |     ldc.reset(problem_count); | 
					
						
							|  |  |  |     ldd.reset(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     lda.copy_from_host(lda_host.data()); | 
					
						
							|  |  |  |     ldb.copy_from_host(ldb_host.data()); | 
					
						
							|  |  |  |     ldc.copy_from_host(ldc_host.data()); | 
					
						
							|  |  |  |     ldd.copy_from_host(ldd_host.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Assign pointers
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     block_A.reset(total_elements_A); | 
					
						
							|  |  |  |     block_B.reset(total_elements_B); | 
					
						
							|  |  |  |     block_C.reset(total_elements_C); | 
					
						
							|  |  |  |     block_D.reset(total_elements_D); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::vector<ElementA *> ptr_A_host(problem_count); | 
					
						
							|  |  |  |     std::vector<ElementB *> ptr_B_host(problem_count); | 
					
						
							|  |  |  |     std::vector<ElementC *> ptr_C_host(problem_count); | 
					
						
							|  |  |  |     std::vector<ElementC *> ptr_D_host(problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  |       ptr_A_host.at(i) = block_A.get() + offset_A.at(i); | 
					
						
							|  |  |  |       ptr_B_host.at(i) = block_B.get() + offset_B.at(i); | 
					
						
							|  |  |  |       ptr_C_host.at(i) = block_C.get() + offset_C.at(i); | 
					
						
							|  |  |  |       ptr_D_host.at(i) = block_D.get() + offset_D.at(i); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ptr_A.reset(problem_count); | 
					
						
							|  |  |  |     ptr_A.copy_from_host(ptr_A_host.data()); | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     ptr_B.reset(problem_count); | 
					
						
							|  |  |  |     ptr_B.copy_from_host(ptr_B_host.data()); | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     ptr_C.reset(problem_count); | 
					
						
							|  |  |  |     ptr_C.copy_from_host(ptr_C_host.data()); | 
					
						
							|  |  |  |      | 
					
						
							|  |  |  |     ptr_D.reset(problem_count); | 
					
						
							|  |  |  |     ptr_D.copy_from_host(ptr_D_host.data()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  |     // Initialize the problems of the workspace
 | 
					
						
							|  |  |  |     //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       LayoutA layout_A(lda_host.at(i)); | 
					
						
							|  |  |  |       LayoutB layout_B(ldb_host.at(i)); | 
					
						
							|  |  |  |       LayoutC layout_C(ldc_host.at(i)); | 
					
						
							|  |  |  |       LayoutC layout_D(ldd_host.at(i)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       MatrixCoord extent_A{problem.m(), problem.k()}; | 
					
						
							|  |  |  |       MatrixCoord extent_B{problem.k(), problem.n()}; | 
					
						
							|  |  |  |       MatrixCoord extent_C{problem.m(), problem.n()}; | 
					
						
							|  |  |  |        | 
					
						
							|  |  |  |       std::vector<ElementA> matrix_A(layout_A.capacity(extent_A)); | 
					
						
							|  |  |  |       std::vector<ElementB> matrix_B(layout_B.capacity(extent_B)); | 
					
						
							|  |  |  |       std::vector<ElementC> matrix_C(layout_C.capacity(extent_C)); | 
					
						
							|  |  |  |       std::vector<ElementC> matrix_D(layout_D.capacity(extent_C)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       initialize_tensor(cutlass::TensorView<ElementA, LayoutA>(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); | 
					
						
							|  |  |  |       initialize_tensor(cutlass::TensorView<ElementB, LayoutB>(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); | 
					
						
							|  |  |  |       initialize_tensor(cutlass::TensorView<ElementC, LayoutC>(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Verifies the result is a GEMM
 | 
					
						
							|  |  |  |   bool verify( | 
					
						
							|  |  |  |     ElementCompute alpha,  | 
					
						
							|  |  |  |     ElementCompute beta) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     bool passed = true; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     for (int32_t i = 0; i < problem_count; ++i) { | 
					
						
							|  |  |  |       cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       LayoutA layout_A(lda_host.at(i)); | 
					
						
							|  |  |  |       LayoutB layout_B(ldb_host.at(i)); | 
					
						
							|  |  |  |       LayoutC layout_C(ldc_host.at(i)); | 
					
						
							|  |  |  |       LayoutC layout_D(ldd_host.at(i)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       MatrixCoord extent_A{problem.m(), problem.k()}; | 
					
						
							|  |  |  |       MatrixCoord extent_B{problem.k(), problem.n()}; | 
					
						
							|  |  |  |       MatrixCoord extent_C{problem.m(), problem.n()}; | 
					
						
							|  |  |  |        | 
					
						
							|  |  |  |       std::vector<ElementA> matrix_A(layout_A.capacity(extent_A)); | 
					
						
							|  |  |  |       std::vector<ElementB> matrix_B(layout_B.capacity(extent_B)); | 
					
						
							|  |  |  |       std::vector<ElementC> matrix_C(layout_C.capacity(extent_C)); | 
					
						
							|  |  |  |       std::vector<ElementC> matrix_D(layout_D.capacity(extent_C)); | 
					
						
							|  |  |  |       std::vector<ElementC> matrix_Ref(layout_D.capacity(extent_C)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + offset_A.at(i), matrix_A.size()); | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); | 
					
						
							|  |  |  |       cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       cutlass::TensorView<ElementA, LayoutA> view_A(matrix_A.data(), layout_A, extent_A); | 
					
						
							|  |  |  |       cutlass::TensorView<ElementB, LayoutB> view_B(matrix_B.data(), layout_B, extent_B); | 
					
						
							|  |  |  |       cutlass::TensorView<ElementC, LayoutC> view_C(matrix_C.data(), layout_C, extent_C); | 
					
						
							|  |  |  |       cutlass::TensorView<ElementC, LayoutC> view_D(matrix_D.data(), layout_D, extent_C); | 
					
						
							|  |  |  |       cutlass::TensorView<ElementC, LayoutC> view_Ref(matrix_Ref.data(), layout_D, extent_C); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Reference GEMM
 | 
					
						
							|  |  |  |       cutlass::reference::host::GemmComplex< | 
					
						
							|  |  |  |           ElementA, LayoutA, | 
					
						
							|  |  |  |           ElementB, LayoutB, | 
					
						
							|  |  |  |           ElementC, LayoutC,  | 
					
						
							|  |  |  |           ElementCompute, ElementAccumulator | 
					
						
							|  |  |  |       >( | 
					
						
							|  |  |  |         problem, | 
					
						
							|  |  |  |         alpha,  | 
					
						
							|  |  |  |         view_A, | 
					
						
							|  |  |  |         Gemm::kTransformA, | 
					
						
							|  |  |  |         view_B, | 
					
						
							|  |  |  |         Gemm::kTransformB, | 
					
						
							|  |  |  |         beta,  | 
					
						
							|  |  |  |         view_C,  | 
					
						
							|  |  |  |         view_Ref,  | 
					
						
							|  |  |  |         ElementAccumulator(0) | 
					
						
							|  |  |  |       ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Ensure that no input or output is entirely zero
 | 
					
						
							|  |  |  |       EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); | 
					
						
							|  |  |  |       EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); | 
					
						
							|  |  |  |       EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); | 
					
						
							|  |  |  |       EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); | 
					
						
							|  |  |  |       EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compare against reference
 | 
					
						
							|  |  |  |       passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       if (!passed) { | 
					
						
							|  |  |  |         std::ofstream file("testbed_grouped_errors.txt"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         file | 
					
						
							|  |  |  |           << "problem: " << problem << "  [group: " << i << "]\n"  | 
					
						
							|  |  |  |           << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         file  | 
					
						
							|  |  |  |           << "A =\n" << view_A | 
					
						
							|  |  |  |           << "\nB =\n" << view_B | 
					
						
							|  |  |  |           << "\nC =\n" << view_C | 
					
						
							|  |  |  |           << "\n\nReference =\n" << view_Ref | 
					
						
							|  |  |  |           << "\nComputed =\n" << view_D; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return passed; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return passed; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   /// Executes one test
 | 
					
						
							|  |  |  |   bool run( | 
					
						
							|  |  |  |     int problem_count, | 
					
						
							|  |  |  |     ElementCompute alpha = ElementCompute(1),  | 
					
						
							|  |  |  |     ElementCompute beta = ElementCompute(0)) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     this->problem_count = problem_count; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Initialize the problem
 | 
					
						
							|  |  |  |     initialize(); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Early exit
 | 
					
						
							|  |  |  |     if (!threadblock_count) { | 
					
						
							|  |  |  |       if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { | 
					
						
							|  |  |  |         std::cerr << "Test waived due to insufficient CUDA device resources." << std::endl; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |       return true; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     // Configure the GEMM arguments
 | 
					
						
							|  |  |  |     typename EpilogueOutputOp::Params epilogue_op(alpha, beta); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Configure GEMM arguments
 | 
					
						
							|  |  |  |     typename Gemm::Arguments args( | 
					
						
							|  |  |  |       problem_sizes_device.get(), | 
					
						
							|  |  |  |       problem_count, | 
					
						
							|  |  |  |       threadblock_count, | 
					
						
							|  |  |  |       epilogue_op, | 
					
						
							|  |  |  |       ptr_A.get(), | 
					
						
							|  |  |  |       ptr_B.get(), | 
					
						
							|  |  |  |       ptr_C.get(), | 
					
						
							|  |  |  |       ptr_D.get(), | 
					
						
							|  |  |  |       lda.get(), | 
					
						
							|  |  |  |       ldb.get(), | 
					
						
							|  |  |  |       ldc.get(), | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |       ldd.get(), | 
					
						
							|  |  |  |       problem_sizes_host.data() | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  |     ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Initialize the GEMM object
 | 
					
						
							|  |  |  |     Gemm gemm; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-09-04 06:48:46 +08:00
										 |  |  |     size_t workspace_size = gemm.get_workspace_size(args); | 
					
						
							|  |  |  |     cutlass::DeviceAllocation<uint8_t> workspace(workspace_size); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     cutlass::Status status = gemm.initialize(args, workspace.get()); | 
					
						
							| 
									
										
										
										
											2021-11-20 05:26:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |       return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Run the GEMM object
 | 
					
						
							|  |  |  |     status = gemm.run(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (status != cutlass::Status::kSuccess) { | 
					
						
							|  |  |  |       return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Wait for completion
 | 
					
						
							|  |  |  |     cudaError_t result = cudaDeviceSynchronize(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     EXPECT_EQ(result, cudaSuccess)  | 
					
						
							|  |  |  |       << "Kernel execution error: " << cudaGetErrorString(result); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if (result != cudaSuccess) { | 
					
						
							|  |  |  |       return false; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Verify correctness
 | 
					
						
							|  |  |  |     return verify(alpha, beta); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // device
 | 
					
						
							|  |  |  | } // gemm
 | 
					
						
							|  |  |  | } // test
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////
 |