568 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			568 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"
 | |
| 
 | |
// Pipelined GEMM kernel.
//
// Each thread block accumulates, over all k-tiles,
//     acc(m,n) += A(m,k) * B(n,k)
// for one (BLK_M,BLK_N) tile of C and finishes with
//     C = alpha * acc + beta * C        (see axpby in the epilogue).
//
// Launch contract (all established by the code below):
//   * blockIdx.x selects the M-tile, blockIdx.y selects the N-tile.
//   * The block must have exactly size(TiledMma) threads; both tiled copies must
//     use the same number of threads.
//   * Static shared memory: cosize_v<ASmemLayout> elements of TA plus
//     cosize_v<BSmemLayout> elements of TB. The smem layouts carry a trailing
//     PIPE mode (number of pipeline stages).
//   * The CSmemLayout argument is unnamed and only participates in the
//     compile-time shape checks.
//   * No bounds predication is performed: every block copies full A/B tiles and
//     writes a full C tile, so the caller must make sure the tiles selected by
//     the grid lie inside the tensors.
//   * Global->shared copies go through copy_a/copy_b followed by
//     cp_async_fence()/cp_async_wait(); the tutorial note further down refers to
//     these as SM80 cp.async instructions.
template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class TiledCopyA,
          class TB, class BStride, class BSmemLayout, class TiledCopyB,
          class TC, class CStride, class CSmemLayout, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
            TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
            TC      * C, CStride dC, CSmemLayout          , TiledMma mma,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));                     // NumThreads
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));                     // NumThreads

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  // Mode 0 of the A and C layouts is M, mode 0 of B and mode 1 of C is N,
  // mode 1 of A and B is K.
  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K,PIPE)

  //
  // Partition the copying of A and B tiles across the threads
  //

  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor tAgA = thr_copy_a.partition_S(gA);                            // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA);                            // (CPY,CPY_M,CPY_K,PIPE)

  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor tBgB = thr_copy_b.partition_S(gB);                            // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB);                            // (CPY,CPY_N,CPY_K,PIPE)

  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB));                // CPY_K

  //
  // PREFETCH
  //

  // Number of smem pipeline stages (the PIPE mode of the smem layout)
  auto K_PIPE_MAX = size<3>(tAsA);

  // Total count of tiles
  int k_tile_count = size<3>(tAgA);
  // Current tile index in gmem to read from
  int k_tile_next = 0;

  // Start async loads for all pipes but the last
  CUTE_UNROLL
  for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) {
    copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe));
    copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe));
    cp_async_fence();
    --k_tile_count;
    // k_tile_next is frozen once the tiles are exhausted, so any further copy
    // re-reads the last valid tile instead of running past the end of gmem.
    if (k_tile_count > 0) { ++k_tile_next; }
  }

  //
  // Define A/B partitioning and C accumulators
  //

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                               // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB);                               // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC);                               // (MMA,MMA_M,MMA_N)

  // Allocate registers for pipelining
  Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0));                // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0));                // (MMA,MMA_N,MMA_K)
  // Allocate the accumulators -- same size as the projected data
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                         // (MMA,MMA_M,MMA_N)

  CUTE_STATIC_ASSERT_V(  shape(tCrA) ==   shape(tCsA));                // (MMA,MMA_M,MMA_K)
  CUTE_STATIC_ASSERT_V(  shape(tCrB) ==   shape(tCsB));                // (MMA,MMA_N,MMA_K)
  CUTE_STATIC_ASSERT_V(  shape(tCrC) ==   shape(tCgC));                // (MMA,MMA_M,MMA_N)
  CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA));                // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB));                // MMA_N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // MMA_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrA : "); print(tCrA); print("\n");
    print("tCrB : "); print(tCrB); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // Current pipe index in smem to read from
  int smem_pipe_read  = 0;
  // Current pipe index in smem to write to
  int smem_pipe_write = K_PIPE_MAX-1;

  // Pipe slice
  Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
  Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);

  // Size of the register pipeline
  auto K_BLOCK_MAX = size<2>(tCrA);

  // PREFETCH register pipeline
  if (K_BLOCK_MAX > 1) {
    // Wait until our first prefetched tile is loaded in
    cp_async_wait<K_PIPE_MAX-2>();
    __syncthreads();

    // Prefetch the first rmem from the first k-tile
    copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{}));
    copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{}));
  }

  //
  // PIPELINED MAIN LOOP
  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's cp.async instructions
  //           and explicit pipelines in shared memory.
  //   Data is read from global(k_tile_next) to shared(smem_pipe_write).
  //   Data is read from shared(smem_pipe_read) to registers(k_block_next).
  //   Data is computed on registers(k_block).
  //
  //   This allows all copies and compute to overlap:
  //     Copy from gmem->smem can overlap with copies from smem->rmem and compute on rmem.
  //     Copy from smem->rmem can overlap with compute on rmem.
  //

  // The prefetch already decremented k_tile_count K_PIPE_MAX-1 times, so the loop
  // keeps running until the count reaches -(K_PIPE_MAX-1): the extra iterations
  // consume the tiles that are still in flight in shared memory.
  CUTE_NO_UNROLL
  while (k_tile_count > -(K_PIPE_MAX-1))
  {
    CUTE_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
    {
      if (k_block == K_BLOCK_MAX - 1)
      {
        // Slice the smem_pipe_read smem
        tCsA_p = tCsA(_,_,_,smem_pipe_read);
        tCsB_p = tCsB(_,_,_,smem_pipe_read);

        // Commit the smem for smem_pipe_read
        cp_async_wait<K_PIPE_MAX-2>();
        __syncthreads();
      }

      // Load A, B shmem->regs for k_block+1
      auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;      // static
      copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next));
      copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next));
      // Copy gmem to smem before computing gemm on each k-pipe
      if (k_block == 0)
      {
        copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write));
        copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write));
        cp_async_fence();

        // Advance the gmem tile
        --k_tile_count;
        if (k_tile_count > 0) { ++k_tile_next; }

        // Advance the smem pipe: the stage just consumed becomes the next write
        // target and the read index wraps around after the last stage.
        smem_pipe_write = smem_pipe_read;
        ++smem_pipe_read;
        smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read;
      }
      // Thread-level register gemm for k_block
      gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
    }

  }

#endif

  //
  // Epilogue
  //

  // tCgC = alpha * tCrC + beta * tCgC, written straight to global memory
  axpby(alpha, tCrC, beta, tCgC);
}
 | |
| 
 | |
// Setup params for a NT GEMM and launch gemm_device on `stream`.
//
//   A is addressed as (M,K) with strides (1, ldA)
//   B is addressed as (N,K) with strides (1, ldB)
//   C is addressed as (M,N) with strides (1, ldC)
//
// Configuration chosen here: 128x128x8 CTA tile, 3 shared-memory pipeline
// stages, 256 threads per block (16x16x1 tiled MMA, 32x8 copy thread layout).
//
// Preconditions: gemm_device performs no bounds predication while the grid is
// sized with ceil_div, so m and n must be multiples of 128 and k a multiple
// of 8; otherwise the edge tiles access memory outside A, B and C. This is
// checked with assert() (debug builds only).
// NOTE(review): the copies use a 128-bit cp.async atom moving 4 values along
// the unit-stride mode; this presumably needs 16-byte aligned A/B pointers and
// leading dimensions -- confirm against the cp.async alignment rules.
//
// The launch is asynchronous and its status is not checked here; callers are
// expected to query the CUDA error state themselves.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};  // Pipeline

  // The kernel has no edge handling: reject partial tiles in debug builds.
  assert(M % decltype(bM)::value == 0 && "gemm_nt: m must be a multiple of BLK_M");
  assert(N % decltype(bN)::value == 0 && "gemm_nt: n must be a multiple of BLK_N");
  assert(K % decltype(bK)::value == 0 && "gemm_nt: k must be a multiple of BLK_K");

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK, bP));             // (m,k,p) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK, bP));             // (n,k,p) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
                                    Layout<Shape<_32,_8>>{}, // Thr layout 32x8 m-major
                                    Layout<Shape< _4,_1>>{});// Val layout  4x1 m-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
                                    Layout<Shape<_32,_8>>{}, // Thr layout 32x8 n-major
                                    Layout<Shape< _4,_1>>{});// Val layout  4x1 n-major

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  // One thread per MMA slot; one CTA per (BLK_M, BLK_N) tile of C.
  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}
 | |
| 
 | |
// Setup params for a TN GEMM and launch gemm_device on `stream`.
//
//   A is addressed as (M,K) with strides (ldA, 1)
//   B is addressed as (N,K) with strides (ldB, 1)
//   C is addressed as (M,N) with strides (1, ldC)
//
// Configuration chosen here: 128x128x8 CTA tile, 3 shared-memory pipeline
// stages, 256 threads per block (16x16x1 tiled MMA, 32x8 k-major copy thread
// layout moving one element per copy). The shared-memory tiles are m-/n-major
// with the k-stride padded by one element (BLK+1).
//
// Preconditions: gemm_device performs no bounds predication while the grid is
// sized with ceil_div, so m and n must be multiples of 128 and k a multiple
// of 8; otherwise the edge tiles access memory outside A, B and C. This is
// checked with assert() (debug builds only).
//
// The launch is asynchronous and its status is not checked here; callers are
// expected to query the CUDA error state themselves.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};  // Pipeline

  // The kernel has no edge handling: reject partial tiles in debug builds.
  assert(M % decltype(bM)::value == 0 && "gemm_tn: m must be a multiple of BLK_M");
  assert(N % decltype(bN)::value == 0 && "gemm_tn: n must be a multiple of BLK_N");
  assert(K % decltype(bK)::value == 0 && "gemm_tn: k must be a multiple of BLK_K");

  // Define the smem layouts (static)
  auto sA_atom = make_layout(make_shape (      bM,          bK),
                             make_stride(Int<1>{}, bM+Int<1>{}));   // (m,k) -> smem_idx; padded m-major
  auto sB_atom = make_layout(make_shape (      bN,          bK),
                             make_stride(Int<1>{}, bN+Int<1>{}));   // (n,k) -> smem_idx; padded n-major
  auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
  // B must be tiled from its own atom; using sA_atom is only equivalent while bM == bN.
  auto sB = tile_to_shape(sB_atom, make_shape(bN, bK, bP));
  auto sC = make_layout(make_shape(bM, bN));                        // (m,n) -> smem_idx

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TA>, TA>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TB>, TB>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  // One thread per MMA slot; one CTA per (BLK_M, BLK_N) tile of C.
  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}
 | |
| 
 | |
// Dispatch a GEMM to the implementation matching the transpose flags.
//
//   transA == 'N' && transB == 'T'  ->  gemm_nt
//   transA == 'T' && transB == 'N'  ->  gemm_tn
//
// Any other combination is not implemented: a diagnostic is written to stderr
// (so the failure is visible even when assert() is compiled out with NDEBUG),
// the assertion fires in debug builds, and no kernel is launched.
// All remaining arguments are forwarded unchanged.
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }

  std::fprintf(stderr,
               "gemm: transA='%c', transB='%c' is not implemented (supported: NT, TN)\n",
               transA, transB);
  assert(false && "Not implemented");
}
 | |
| 
 | |
| 
 | |
// Benchmark driver.
//
// Usage: <exe> [M] [N] [K] [transA] [transB]
//   Defaults: M = 5120, N = 5120, K = 4096, transA = 'N', transB = 'T'.
//
// Fills A and B with pseudo-random values in [-1, 1], runs the GEMM once,
// then times 100 further launches and prints the achieved GFlop/s.
// Returns 0 on success (and on GPUs older than compute capability 8.0, so
// that test runs on unsupported hardware still pass), -1 on a CUDA query
// failure or on invalid command-line arguments.
int main(int argc, char** argv)
{
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major < 8) {
    std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
    // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
    return 0;
  }

  //
  // Command line
  //

  int m = 5120;
  if (argc >= 2 && sscanf(argv[1], "%d", &m) != 1) {
    std::cerr << "Invalid value for M: " << argv[1] << std::endl;
    return -1;
  }

  int n = 5120;
  if (argc >= 3 && sscanf(argv[2], "%d", &n) != 1) {
    std::cerr << "Invalid value for N: " << argv[2] << std::endl;
    return -1;
  }

  int k = 4096;
  if (argc >= 4 && sscanf(argv[3], "%d", &k) != 1) {
    std::cerr << "Invalid value for K: " << argv[3] << std::endl;
    return -1;
  }

  char transA = 'N';
  if (argc >= 5 && sscanf(argv[4], "%c", &transA) != 1) {
    std::cerr << "Invalid value for transA" << std::endl;
    return -1;
  }

  char transB = 'T';
  if (argc >= 6 && sscanf(argv[5], "%c", &transB) != 1) {
    std::cerr << "Invalid value for transB" << std::endl;
    return -1;
  }

  //
  // Validation (explicit, so it also works when assert() is compiled out)
  //

  if (m <= 0 || n <= 0 || k <= 0) {
    std::cerr << "M, N and K must be positive" << std::endl;
    return -1;
  }

  // These mirror the CTA tile sizes (bM, bN, bK) hard-coded in gemm_nt/gemm_tn.
  // The kernel does not predicate partial tiles, so other sizes would access
  // memory out of bounds.
  const int tile_m = 128;
  const int tile_n = 128;
  const int tile_k = 8;
  if (m % tile_m != 0 || n % tile_n != 0 || k % tile_k != 0) {
    std::cerr << "M and N must be multiples of " << tile_m << "/" << tile_n
              << " and K a multiple of " << tile_k << std::endl;
    return -1;
  }

  const bool is_nt = (transA == 'N' && transB == 'T');
  const bool is_tn = (transA == 'T' && transB == 'N');
  if (!is_nt && !is_tn) {
    std::cerr << "Unsupported transposition A^" << transA << " B^" << transB
              << " (supported: N T, T N)" << std::endl;
    return -1;
  }

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0f;
  TI beta  = 0.0f;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  //
  // Host and device data
  //

  // Element counts are computed in size_t: the int products can overflow.
  const std::size_t size_A = std::size_t(m) * std::size_t(k);
  const std::size_t size_B = std::size_t(n) * std::size_t(k);
  const std::size_t size_C = std::size_t(m) * std::size_t(n);

  thrust::host_vector<TA> h_A(size_A);
  thrust::host_vector<TB> h_B(size_B);
  thrust::host_vector<TC> h_C(size_C);

  for (std::size_t j = 0; j < size_A; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (std::size_t j = 0; j < size_B; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (std::size_t j = 0; j < size_C; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  // Leading dimensions for densely packed operands:
  //   'N' stores A as (M,K) m-major / B as (K,N) k-major,
  //   'T' stores A as (K,M) k-major / B as (N,K) n-major.
  // C is always (M,N) m-major.
  const int ldA = (transA == 'N') ? m : k;
  const int ldB = (transB == 'N') ? k : n;
  const int ldC = m;

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:     [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}
 | 
