| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  | /*************************************************************************************************** | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |  * | 
					
						
							| 
									
										
										
										
											2022-04-24 03:02:38 +08:00
										 |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							| 
									
										
										
										
											2021-07-23 12:40:53 +08:00
										 |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /* | 
					
						
							|  |  |  |   This example requires NVIDIA Maxwell GPU or beyond. | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Standard Library includes | 
					
						
							|  |  |  | #include <iostream> | 
					
						
							|  |  |  | #include <sstream> | 
					
						
							|  |  |  | #include <vector> | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // CUTLASS Includes | 
					
						
							|  |  |  | #include "cutlass/cutlass.h" | 
					
						
							|  |  |  | #include "cutlass/core_io.h" | 
					
						
							|  |  |  | #include "cutlass/functional.h" | 
					
						
							|  |  |  | #include "cutlass/layout/matrix.h" | 
					
						
							|  |  |  | #include "cutlass/gemm/warp/mma_simt.h" | 
					
						
							|  |  |  | #include "cutlass/epilogue/warp/fragment_iterator_simt.h" | 
					
						
							|  |  |  | #include "cutlass/epilogue/warp/tile_iterator_simt.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // CUTLASS Utility Includes | 
					
						
							|  |  |  | #include "cutlass/util/host_tensor.h" | 
					
						
							|  |  |  | #include "cutlass/util/tensor_view_io.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/gemm.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/tensor_compare.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/tensor_fill.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/tensor_copy.h" | 
					
						
							|  |  |  | #include "cutlass/util/reference/host/gemm_complex.h" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							// Overall warp-level problem shape: A is kM x kK, B is kK x kN, and C/D are kM x kN.
							constexpr int kM = 14;
							constexpr int kN = 27;
							constexpr int kK = 17;
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Define a warp-level GEMM operator. | 
					
						
							|  |  |  | // | 
					
						
							|  |  |  | // This template could be part of the CUTLASS Template Library or implemented internally. This | 
					
						
							|  |  |  | // wraps the matrix multiply operation and epilogue with a GEMM-like interface that can be | 
					
						
							|  |  |  | // instantiated in device code. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace gemm { | 
					
						
							|  |  |  | namespace warp { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   typename Shape, | 
					
						
							|  |  |  |   typename ElementA, | 
					
						
							|  |  |  |   typename LayoutA, | 
					
						
							|  |  |  |   typename ElementB, | 
					
						
							|  |  |  |   typename LayoutB, | 
					
						
							|  |  |  |   typename ElementC, | 
					
						
							|  |  |  |   typename LayoutC, | 
					
						
							|  |  |  |   typename ElementScalar | 
					
						
							|  |  |  | > | 
					
						
							|  |  |  | class GemmSimt { | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using Policy = cutlass::gemm::warp::MmaSimtPolicy< | 
					
						
							|  |  |  |     cutlass::MatrixShape<4, 8>, | 
					
						
							|  |  |  |     cutlass::layout::RowMajorInterleaved<2>, | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<4, 4, 1> | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using MmaWarp = cutlass::gemm::warp::MmaSimt< | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<16, 32, 8>, | 
					
						
							|  |  |  |     float, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     float, | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     float, | 
					
						
							|  |  |  |     cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     Policy | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Number of 'K groups' | 
					
						
							|  |  |  |   int const kKgroups = Shape::kK; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using FragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< | 
					
						
							|  |  |  |     typename MmaWarp::Shape, | 
					
						
							|  |  |  |     typename MmaWarp::ThreadMma, | 
					
						
							|  |  |  |     layout::RowMajor,                // SMEM layout | 
					
						
							|  |  |  |     typename MmaWarp::Policy | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using AccumulatorTileIterator = cutlass::epilogue::warp::TileIteratorSimtCanonical< | 
					
						
							|  |  |  |     typename MmaWarp::Shape, | 
					
						
							|  |  |  |     typename MmaWarp::ThreadMma, | 
					
						
							|  |  |  |     float,                             // ElementAccumulator | 
					
						
							|  |  |  |     layout::RowMajor,                  // SMEM layout | 
					
						
							|  |  |  |     typename MmaWarp::Policy | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using TensorRefA = typename MmaWarp::IteratorA::TensorRef; | 
					
						
							|  |  |  |   using TensorRefB = typename MmaWarp::IteratorB::TensorRef; | 
					
						
							|  |  |  |   using TensorRefC = typename AccumulatorTileIterator::TensorRef; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  |   CUTLASS_HOST_DEVICE | 
					
						
							|  |  |  |   GemmSimt() { } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTLASS_DEVICE | 
					
						
							|  |  |  |   void operator()( | 
					
						
							|  |  |  |     ElementScalar alpha,  | 
					
						
							|  |  |  |     TensorRefA ref_A,  | 
					
						
							|  |  |  |     TensorRefB ref_B,  | 
					
						
							|  |  |  |     ElementScalar beta, | 
					
						
							|  |  |  |     TensorRefC ref_C, | 
					
						
							|  |  |  |     TensorRefC ref_D, | 
					
						
							|  |  |  |     int lane_id) const { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Instantiate iterators pointing to slices of the A and B matrices in shared memory | 
					
						
							|  |  |  |     typename MmaWarp::IteratorA iter_A(ref_A, {Shape::kM, Shape::kK}, lane_id); | 
					
						
							|  |  |  |     typename MmaWarp::IteratorB iter_B(ref_B, {Shape::kK, Shape::kN}, lane_id); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Instantiate and clear accumulator tile holding the C matrix | 
					
						
							|  |  |  |     typename MmaWarp::FragmentC accum; | 
					
						
							|  |  |  |     accum.clear(); | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |     // Instantiate the warp-level matrix multiply operator | 
					
						
							|  |  |  |     MmaWarp mma_op; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Instantiate fragments holding the slice of the matrix held by each warp | 
					
						
							|  |  |  |     typename MmaWarp::FragmentA frag_A[2]; | 
					
						
							|  |  |  |     typename MmaWarp::FragmentB frag_B[2]; | 
					
						
							|  |  |  |        | 
					
						
							|  |  |  |     // Load fragments from shared memory | 
					
						
							|  |  |  |     iter_A.load(frag_A[0]); | 
					
						
							|  |  |  |     iter_B.load(frag_B[0]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     ++iter_A; | 
					
						
							|  |  |  |     ++iter_B; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Load fragments from shared memory | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int k = 0; k < kKgroups; ++k) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Load fragments from shared memory | 
					
						
							|  |  |  |       iter_A.load(frag_A[(k + 1) % 2]); | 
					
						
							|  |  |  |       iter_B.load(frag_B[(k + 1) % 2]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       ++iter_A; | 
					
						
							|  |  |  |       ++iter_B; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute the matrix multiply | 
					
						
							|  |  |  |       mma_op(accum, frag_A[k % 2], frag_B[k % 2], accum); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |     // Instantiate iterators | 
					
						
							|  |  |  |     FragmentIterator accum_frag_it(accum); | 
					
						
							|  |  |  |     AccumulatorTileIterator source_tile_it(ref_C, {Shape::kM, Shape::kN}, lane_id); | 
					
						
							|  |  |  |     AccumulatorTileIterator dest_tile_it(ref_D, {Shape::kM, Shape::kN}, lane_id); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Define function objects for linear scaling operation | 
					
						
							|  |  |  |     cutlass::multiplies<typename FragmentIterator::Fragment> mul_source; | 
					
						
							|  |  |  |     cutlass::multiply_add<typename FragmentIterator::Fragment> mul_add_accumulator; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Iterate over the epilogue components | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_UNROLL | 
					
						
							|  |  |  |     for (int idx = 0; idx < FragmentIterator::kIterations; ++idx) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Define storage for slices of the accumulators | 
					
						
							|  |  |  |       typename FragmentIterator::Fragment accum_fragment; | 
					
						
							|  |  |  |       typename FragmentIterator::Fragment source_fragment; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Select a slice of accumulators from the accumulator tile | 
					
						
							|  |  |  |       accum_frag_it.load(accum_fragment); | 
					
						
							|  |  |  |       ++accum_frag_it; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Load a corresponding slice from Shared memory | 
					
						
							|  |  |  |       source_tile_it.load(source_fragment); | 
					
						
							|  |  |  |       ++source_tile_it; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Compute linear scaling - alpha * AB + beta * C | 
					
						
							|  |  |  |       source_fragment = mul_source(beta, source_fragment); | 
					
						
							|  |  |  |       accum_fragment = mul_add_accumulator(alpha, accum_fragment, source_fragment); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // Store the result to shared memory | 
					
						
							|  |  |  |       dest_tile_it.store(accum_fragment); | 
					
						
							|  |  |  |       ++dest_tile_it; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace warp | 
					
						
							|  |  |  | } // namespace gemm | 
					
						
							|  |  |  | } // namespace cutlass | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Sample kernel demonstrating a collective GEMM operation by a warp on arbitrary matrices held | 
					
						
							|  |  |  | // in Shared Memory. | 
					
						
							|  |  |  | __global__ void kernel( | 
					
						
							|  |  |  |   float *D_gmem,  | 
					
						
							|  |  |  |   float alpha,  | 
					
						
							|  |  |  |   float const *A_gmem,  | 
					
						
							|  |  |  |   float const *B_gmem,  | 
					
						
							|  |  |  |   float beta, | 
					
						
							|  |  |  |   float const *C_gmem) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Define several matrices in shared memory | 
					
						
							|  |  |  |   __shared__ float A[kM][kK]; | 
					
						
							|  |  |  |   __shared__ float B[kN][kK]; | 
					
						
							|  |  |  |   __shared__ float C[kM][kN]; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Copy data into SMEM | 
					
						
							|  |  |  |   if (threadIdx.x == 0) { | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |     for (int m = 0; m < kM; ++m) { | 
					
						
							|  |  |  |       for (int k = 0; k < kK; ++k) { | 
					
						
							|  |  |  |         A[m][k] = A_gmem[m * kK + k]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |     for (int n = 0; n < kN; ++n) { | 
					
						
							|  |  |  |       for (int k = 0; k < kK; ++k) { | 
					
						
							|  |  |  |         B[n][k] = B_gmem[n * kK + k]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |     for (int m = 0; m < kM; ++m) { | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |       for (int n = 0; n < kN; ++n) { | 
					
						
							|  |  |  |         C[m][n] = C_gmem[m * kN + n]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   __syncthreads(); | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  |   // Instantiate a warp-level matrix multiply operator given the fundamental instruction shape (8x8x4), | 
					
						
							|  |  |  |   // overall shape, data type of each operand, and layout of each operand. | 
					
						
							|  |  |  |   // | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GemmSimt = cutlass::gemm::warp::GemmSimt< | 
					
						
							|  |  |  |     cutlass::gemm::GemmShape<kM, kN, kK>, | 
					
						
							|  |  |  |     float,                             // Data type of A elements | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,          // Layout of A matrix | 
					
						
							|  |  |  |     float,                             // Data type of B elements | 
					
						
							|  |  |  |     cutlass::layout::ColumnMajor,       // Layout of B matrix | 
					
						
							|  |  |  |     float,                             // Data type of C elements | 
					
						
							|  |  |  |     cutlass::layout::RowMajor,          // Layout of C matrix | 
					
						
							|  |  |  |     float                              // Scalar type of alpha and beta | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Instantiate the GEMM operator | 
					
						
							|  |  |  |   GemmSimt gemm; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Execute the warp-level GEMM operation | 
					
						
							|  |  |  |   gemm( | 
					
						
							|  |  |  |     alpha,  | 
					
						
							|  |  |  |     {&A[0][0], kK}, | 
					
						
							|  |  |  |     {&B[0][0], kK}, | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     {&C[0][0], kN}, | 
					
						
							|  |  |  |     {&C[0][0], kN}, | 
					
						
							|  |  |  |     threadIdx.x); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   __syncthreads(); | 
					
						
							|  |  |  |    | 
					
						
							|  |  |  |   // Copy data into SMEM | 
					
						
							|  |  |  |   if (threadIdx.x == 0) { | 
					
						
							|  |  |  |     CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |     for (int m = 0; m < kM; ++m) { | 
					
						
							|  |  |  |       CUTLASS_PRAGMA_NO_UNROLL | 
					
						
							|  |  |  |       for (int n = 0; n < kN; ++n) { | 
					
						
							|  |  |  |         D_gmem[m * kN + n] = C[m][n]; | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | int main(int argc, const char *arg[]) {  | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::HostTensor<float, cutlass::layout::RowMajor> A({kM, kK}); | 
					
						
							|  |  |  |   cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({kK, kN}); | 
					
						
							|  |  |  |   cutlass::HostTensor<float, cutlass::layout::RowMajor> C({kM, kN}); | 
					
						
							|  |  |  |   cutlass::HostTensor<float, cutlass::layout::RowMajor> D({kM, kN}); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   uint64_t seed = 2020; | 
					
						
							|  |  |  |   float max = 8; | 
					
						
							|  |  |  |   float min = -8; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   std::cout << "Simt canonical GEMM problem size = (" << cutlass::gemm::GemmShape<kM, kN, kK>() <<")" << std::endl; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::reference::host::TensorFillRandomUniform( | 
					
						
							|  |  |  |     A.host_view(), | 
					
						
							|  |  |  |     seed, | 
					
						
							|  |  |  |     max, | 
					
						
							|  |  |  |     min, | 
					
						
							|  |  |  |     0 | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::reference::host::TensorFillRandomUniform( | 
					
						
							|  |  |  |     B.host_view(), | 
					
						
							|  |  |  |     seed + 17, | 
					
						
							|  |  |  |     max, | 
					
						
							|  |  |  |     min, | 
					
						
							|  |  |  |     0 | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #if 0 // Debug: fill A sequentially and B as Identity matrix for debugging | 
					
						
							|  |  |  |   cutlass::reference::host::BlockFillSequential( | 
					
						
							|  |  |  |         A.host_view().data(), A.host_view().capacity()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::reference::host::TensorFillIdentity(B.host_view()); | 
					
						
							|  |  |  | #endif | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::reference::host::TensorFillRandomUniform( | 
					
						
							|  |  |  |     C.host_view(), | 
					
						
							|  |  |  |     seed + 31, | 
					
						
							|  |  |  |     max, | 
					
						
							|  |  |  |     min, | 
					
						
							|  |  |  |     0 | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   A.sync_device(); | 
					
						
							|  |  |  |   B.sync_device(); | 
					
						
							|  |  |  |   C.sync_device(); | 
					
						
							|  |  |  |   D.sync_device(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   dim3 grid(1, 1); | 
					
						
							|  |  |  |   dim3 block(32, 1, 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   float alpha = 1.0f; | 
					
						
							|  |  |  |   float beta = 0.0f; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   kernel<<< grid, block >>>( | 
					
						
							|  |  |  |     D.device_data(), | 
					
						
							|  |  |  |     alpha, | 
					
						
							|  |  |  |     A.device_data(), | 
					
						
							|  |  |  |     B.device_data(), | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     C.device_data() | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cudaError_t result = cudaDeviceSynchronize(); | 
					
						
							|  |  |  |   if (result != cudaSuccess) { | 
					
						
							|  |  |  |     std::cerr << "Failed to synchronize device after kernel launch." << std::endl; | 
					
						
							|  |  |  |     return -1; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   D.sync_host(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Compute reference on host | 
					
						
							|  |  |  |   cutlass::HostTensor<float, cutlass::layout::RowMajor> D_ref({kM, kN}, false); | 
					
						
							|  |  |  |   cutlass::reference::host::TensorCopy(D_ref.host_view(), C.host_view()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   cutlass::reference::host::Gemm< | 
					
						
							|  |  |  |   float, cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |   float, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |   float, cutlass::layout::RowMajor,  | 
					
						
							|  |  |  |   float, float> reference_gemm; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   reference_gemm( | 
					
						
							|  |  |  |     {kM, kN, kK}, | 
					
						
							|  |  |  |     alpha, | 
					
						
							|  |  |  |     A.host_ref(), | 
					
						
							|  |  |  |     B.host_ref(), | 
					
						
							|  |  |  |     beta, | 
					
						
							|  |  |  |     D_ref.host_ref(), | 
					
						
							|  |  |  |     float() | 
					
						
							|  |  |  |   ); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Verify reference matches computed | 
					
						
							|  |  |  |   if (!cutlass::reference::host::TensorEquals( | 
					
						
							|  |  |  |     D.host_view(), | 
					
						
							|  |  |  |     D_ref.host_view())) { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::cerr  | 
					
						
							|  |  |  |       << "A =\n" << A.host_view()  | 
					
						
							|  |  |  |       << "\n\nB = \n" << B.host_view()  | 
					
						
							|  |  |  |       << "\n\nC = " << C.host_view()  | 
					
						
							|  |  |  |       << "\n\nRef =\n" << D_ref.host_view() | 
					
						
							|  |  |  |       << "\n\nD =\n" << D.host_view() << "\n\n"; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     std::cerr << "Error - device results mismatch host reference." << std::endl; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return -1; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   std::cout << "Passed" << std::endl; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   return 0;  | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | /////////////////////////////////////////////////////////////////////////////////////////////////// |