/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#include <cstdlib>
#include <cstdio>
#include <cassert>
#include <iostream>

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>

#include <cute/tensor.hpp>

#include "cutlass/util/print_error.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/helper_cuda.hpp"

template <class ProblemShape, class CtaTiler,
          class TA, class AStride, class ASmemLayout, class TiledCopyA,
          class TB, class BStride, class BSmemLayout, class TiledCopyB,
          class TC, class CStride, class CSmemLayout, class TiledMma,
          class Alpha, class Beta>
__global__ static
__launch_bounds__(decltype(size(TiledMma{}))::value)
void
gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler,
            TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a,
            TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b,
            TC      * C, CStride dC, CSmemLayout          , TiledMma mma,
            Alpha alpha, Beta beta)
{
  using namespace cute;

  // Preconditions
  CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{});                   // (M, N, K)
  CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{});                   // (BLK_M, BLK_N, BLK_K)

  CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma));                     // NumThreads
  CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma));                     // NumThreads

  static_assert(is_static<ASmemLayout>::value);
  static_assert(is_static<BSmemLayout>::value);
  static_assert(is_static<CSmemLayout>::value);

  CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler));  // BLK_M
  CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler));  // BLK_N
  CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler));  // BLK_K
  CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler));  // BLK_K

  CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA));         // dA strides for shape MK
  CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB));         // dB strides for shape NK
  CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC));         // dC strides for shape MN

  //
  // Full and Tiled Tensors
  //

  // Represent the full tensors
  Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K)
  Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K)
  Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N)

  // Get the appropriate blocks for this thread block
  auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _);              // (m,n,k)
  Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{});  // (BLK_M,BLK_K,k)
  Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{});  // (BLK_N,BLK_K,k)
  Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{});  // (BLK_M,BLK_N)

  // Shared memory buffers
  __shared__ TA smemA[cosize_v<ASmemLayout>];
  __shared__ TB smemB[cosize_v<BSmemLayout>];
  Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout);            // (BLK_M,BLK_K,PIPE)
  Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout);            // (BLK_N,BLK_K,PIPE)

  //
  // Partition the copying of A and B tiles across the threads
  //

  ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x);
  Tensor tAgA = thr_copy_a.partition_S(gA);                            // (CPY,CPY_M,CPY_K,k)
  Tensor tAsA = thr_copy_a.partition_D(sA);                            // (CPY,CPY_M,CPY_K,PIPE)

  ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x);
  Tensor tBgB = thr_copy_b.partition_S(gB);                            // (CPY,CPY_N,CPY_K,k)
  Tensor tBsB = thr_copy_b.partition_D(sB);                            // (CPY,CPY_N,CPY_K,PIPE)

  CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA));                // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA));                // CPY_K
  CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB));                // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB));                // CPY_K
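
  // Concrete shapes for orientation -- a sketch, assuming the default configuration set up
  // by the host code below (BLK_M = BLK_N = 128, BLK_K = 8, 3 pipeline stages, 256 threads)
  // and the default K = 4096 from main():
  //
  //   gA : (128, 8, 512)   512 k-tiles of A assigned to this CTA
  //   gB : (128, 8, 512)
  //   sA : (128, 8, 3)     128*8*3 floats = 12 KB of smem (24 KB together with sB)
  //   tAgA/tAsA : each of the 256 threads moves 4 floats of a 128x8 tile per stage --
  //               one 16-byte cp.async in the gemm_nt setup, four 4-byte cp.asyncs in gemm_tn.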
  //
  // PREFETCH
  //

  auto K_PIPE_MAX = size<3>(tAsA);

  // Total count of tiles
  int k_tile_count = size<3>(tAgA);
  // Current tile index in gmem to read from
  int k_tile_next = 0;

  // Start async loads for all pipes but the last
  CUTE_UNROLL
  for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) {
    copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe));
    copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe));
    cp_async_fence();
    --k_tile_count;
    if (k_tile_count > 0) { ++k_tile_next; }
  }

  //
  // Define A/B partitioning and C accumulators
  //

  ThrMMA thr_mma = mma.get_slice(threadIdx.x);
  Tensor tCsA = thr_mma.partition_A(sA);                               // (MMA,MMA_M,MMA_K,PIPE)
  Tensor tCsB = thr_mma.partition_B(sB);                               // (MMA,MMA_N,MMA_K,PIPE)
  Tensor tCgC = thr_mma.partition_C(gC);                               // (MMA,MMA_M,MMA_N)

  // Allocate registers for pipelining
  Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0));                // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0));                // (MMA,MMA_N,MMA_K)
  // Allocate the accumulators -- same size as the projected data
  Tensor tCrC = thr_mma.make_fragment_C(tCgC);                         // (MMA,MMA_M,MMA_N)

  CUTE_STATIC_ASSERT_V((shape(tCrA) == take<0,3>(shape(tCsA))));       // (MMA,MMA_M,MMA_K)
  CUTE_STATIC_ASSERT_V((shape(tCrB) == take<0,3>(shape(tCsB))));       // (MMA,MMA_N,MMA_K)
  CUTE_STATIC_ASSERT_V((shape(tCrC) == shape(tCgC)));                  // (MMA,MMA_M,MMA_N)
  CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA));                // MMA_M
  CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB));                // MMA_N
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));                // MMA_K

  // Clear the accumulators
  clear(tCrC);

#if 0
  if(thread0()) {
    print("  mA : "); print(  mA); print("\n");
    print("  gA : "); print(  gA); print("\n");
    print("  sA : "); print(  sA); print("\n");
    print("tAgA : "); print(tAgA); print("\n");
    print("tAsA : "); print(tAsA); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mB : "); print(  mB); print("\n");
    print("  gB : "); print(  gB); print("\n");
    print("  sB : "); print(  sB); print("\n");
    print("tBgB : "); print(tBgB); print("\n");
    print("tBsB : "); print(tBsB); print("\n");
  }
#endif

#if 0
  if(thread0()) {
    print("  mC : "); print(  mC); print("\n");
    print("  gC : "); print(  gC); print("\n");
    print("tCsA : "); print(tCsA); print("\n");
    print("tCsB : "); print(tCsB); print("\n");
    print("tCgC : "); print(tCgC); print("\n");
    print("tCrA : "); print(tCrA); print("\n");
    print("tCrB : "); print(tCrB); print("\n");
    print("tCrC : "); print(tCrC); print("\n");
  }
#endif

#if 1

  // Current pipe index in smem to read from
  int smem_pipe_read  = 0;
  // Current pipe index in smem to write to
  int smem_pipe_write = K_PIPE_MAX-1;

  // Pipe slice
  Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read);
  Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read);

  // Size of the register pipeline
  auto K_BLOCK_MAX = size<2>(tCrA);

  // PREFETCH register pipeline
  if (K_BLOCK_MAX > 1) {
    // Wait until our first prefetched tile is loaded in
    cp_async_wait<K_PIPE_MAX-2>();
    __syncthreads();

    // Prefetch the first rmem from the first k-tile
    copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{}));
    copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{}));
  }

  //
  // PIPELINED MAIN LOOP
  // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's cp.async
  //           instructions and an explicit multistage shared-memory pipeline.
  //           Data is read from global(k_tile_next) to shared(smem_pipe_write).
  //           Data is read from shared(smem_pipe_read) to registers(k_block_next).
  //           Data is computed on registers(k_block).
  //
  //           This allows all copies and compute to overlap:
  //             Copy from gmem->smem can overlap with copies from smem->rmem and compute on rmem.
  //             Copy from smem->rmem can overlap with compute on rmem.
  //
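  // A concrete trace of this schedule -- a sketch, assuming K_PIPE_MAX == 3 (the bP = Int<3>{}
  // chosen by the host code below) and K_BLOCK_MAX == 8 (BLK_K = 8 with a 1-element MMA atom):
  //
  //   prologue : k-tiles 0 and 1 are in flight to smem stages 0 and 1 (two cp.async groups);
  //              smem_pipe_read = 0, smem_pipe_write = 2
  //   iter 0   : compute on stage 0 while k-tile 2 streams into stage 2;
  //              then smem_pipe_write = 0, smem_pipe_read = 1
  //   iter 1   : compute on stage 1 while k-tile 3 streams into stage 0;
  //              then smem_pipe_write = 1, smem_pipe_read = 2
  //   iter 2   : compute on stage 2 while k-tile 4 streams into stage 1; and so on.
  //
  // cp_async_wait<K_PIPE_MAX-2>() leaves at most one cp.async group outstanding: exactly the
  // stage currently being written, while the oldest resident stage is consumed. The loop runs
  // until k_tile_count == -(K_PIPE_MAX-1) so that the tiles already resident in smem are
  // drained; the gmem->smem copies issued during those drain iterations re-read the last
  // valid k-tile and are never consumed.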
  CUTE_NO_UNROLL
  while (k_tile_count > -(K_PIPE_MAX-1))
  {
    CUTE_UNROLL
    for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
    {
      if (k_block == K_BLOCK_MAX - 1)
      {
        // Slice the smem_pipe_read smem
        tCsA_p = tCsA(_,_,_,smem_pipe_read);
        tCsB_p = tCsB(_,_,_,smem_pipe_read);

        // Commit the smem for smem_pipe_read
        cp_async_wait<K_PIPE_MAX-2>();
        __syncthreads();
      }

      // Load A, B shmem->regs for k_block+1
      auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX;          // static
      copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next));
      copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next));

      // Copy gmem to smem before computing gemm on each k-pipe
      if (k_block == 0)
      {
        copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write));
        copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write));
        cp_async_fence();

        // Advance the gmem tile
        --k_tile_count;
        if (k_tile_count > 0) { ++k_tile_next; }

        // Advance the smem pipe
        smem_pipe_write = smem_pipe_read;
        ++smem_pipe_read;
        smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read;
      }

      // Thread-level register gemm for k_block
      gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
    }
  }

#endif

  //
  // Epilogue
  //

  axpby(alpha, tCrC, beta, tCgC);
}

// Setup params for an NT GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_nt(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define NT strides (mixed)
  auto dA = make_stride(Int<1>{}, ldA);                      // (dM, dK)
  auto dB = make_stride(Int<1>{}, ldB);                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};                                        // Pipeline

  // Define the smem layouts (static)
  auto sA = make_layout(make_shape(bM, bK, bP));             // (m,k,p) -> smem_idx; m-major
  auto sB = make_layout(make_shape(bN, bK, bP));             // (n,k,p) -> smem_idx; n-major
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx; m-major

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TA>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 m-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 m-major
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, TB>{},
                                    Layout<Shape<_32,_8>>{},  // Thr layout 32x8 n-major
                                    Layout<Shape< _4,_1>>{}); // Val layout  4x1 n-major

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}
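
// How the NT thread layouts above tile the work -- a sketch that follows directly from the
// shapes chosen in gemm_nt (no new parameters are introduced):
//
//   copyA/copyB : 32x8 threads x 4x1 values = one 128x8 tile of A (or B) per pipeline stage,
//                 i.e. a single 16-byte cp.async per thread per stage, contiguous along the
//                 major (m or n) mode.
//   mmaC        : 16x16x1 threads over the 128x128 C tile, so each thread owns an 8x8
//                 accumulator fragment and steps through the BLK_K = 8 k-blocks with scalar FMAs.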
// Setup params for a TN GEMM
template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm_tn(int m, int n, int k,
        Alpha alpha,
        TA const* A, int ldA,
        TB const* B, int ldB,
        Beta beta,
        TC      * C, int ldC,
        cudaStream_t stream = 0)
{
  using namespace cute;

  // Define shapes (dynamic)
  auto M = int(m);
  auto N = int(n);
  auto K = int(k);
  auto prob_shape = make_shape(M, N, K);                     // (M, N, K)

  // Define TN strides (mixed)
  auto dA = make_stride(ldA, Int<1>{});                      // (dM, dK)
  auto dB = make_stride(ldB, Int<1>{});                      // (dN, dK)
  auto dC = make_stride(Int<1>{}, ldC);                      // (dM, dN)

  // Define CTA tile sizes (static)
  auto bM = Int<128>{};
  auto bN = Int<128>{};
  auto bK = Int<  8>{};
  auto cta_tiler = make_shape(bM, bN, bK);                   // (BLK_M, BLK_N, BLK_K)
  auto bP = Int<3>{};                                        // Pipeline

  // Define the smem layouts (static)
  auto sA_atom = make_layout(make_shape ( bM,          bK),
                             make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major
  auto sB_atom = make_layout(make_shape ( bN,          bK),
                             make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major
  auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP));
  auto sB = tile_to_shape(sB_atom, make_shape(bN, bK, bP));
  auto sC = make_layout(make_shape(bM, bN));                 // (m,n) -> smem_idx

  // Define the thread layouts (static)

  TiledCopy copyA = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TA>, TA>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1
  TiledCopy copyB = make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<TB>, TB>{},
                                    Layout<Shape<_32,_8>,Stride<_8,_1>>{}, // Thr layout 32x8 k-major
                                    Layout<Shape< _1,_1>>{});              // Val layout  1x1

  TiledMMA mmaC = make_tiled_mma(UniversalFMA<TC,TA,TB>{},
                                 Layout<Shape<_16,_16,_1>>{});  // 16x16x1 TiledMMA

#if 0
  print(copyA);
  print(copyB);
  print(mmaC);
#endif

#if 0
  print_latex(copyA);
  print_latex(copyB);
  print_latex(mmaC);
#endif

  dim3 dimBlock(size(mmaC));
  dim3 dimGrid(size(ceil_div(M, bM)),
               size(ceil_div(N, bN)));
  gemm_device<<<dimGrid, dimBlock, 0, stream>>>
      (prob_shape, cta_tiler,
       A, dA, sA, copyA,
       B, dB, sB, copyB,
       C, dC, sC, mmaC,
       alpha, beta);
}

template <class TA, class TB, class TC,
          class Alpha, class Beta>
void
gemm(char transA, char transB, int m, int n, int k,
     Alpha alpha,
     TA const* A, int ldA,
     TB const* B, int ldB,
     Beta beta,
     TC      * C, int ldC,
     cudaStream_t stream = 0)
{
  if (transA == 'N' && transB == 'T') {
    return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  } else
  if (transA == 'T' && transB == 'N') {
    return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream);
  }
  assert(false && "Not implemented");
}

int main(int argc, char** argv)
{
  cudaDeviceProp props;
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: "
              << cudaGetErrorString(error) << std::endl;
    return -1;
  }

  if (props.major < 8) {
    std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl;
    // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits.
    return 0;
  }

  int m = 5120;
  if (argc >= 2)
    sscanf(argv[1], "%d", &m);

  int n = 5120;
  if (argc >= 3)
    sscanf(argv[2], "%d", &n);

  int k = 4096;
  if (argc >= 4)
    sscanf(argv[3], "%d", &k);

  char transA = 'N';
  if (argc >= 5)
    sscanf(argv[4], "%c", &transA);

  char transB = 'T';
  if (argc >= 6)
    sscanf(argv[5], "%c", &transB);

  using TA = float;
  using TB = float;
  using TC = float;
  using TI = float;

  TI alpha = 1.0;
  TI beta  = 0.0;

  std::cout << "M = " << m << std::endl;
  std::cout << "N = " << n << std::endl;
  std::cout << "K = " << k << std::endl;
  std::cout << "C = A^" << transA << " B^" << transB << std::endl;

  thrust::host_vector<TA> h_A(m*k);
  thrust::host_vector<TB> h_B(n*k);
  thrust::host_vector<TC> h_C(m*n);

  for (int j = 0; j < m*k; ++j) h_A[j] = static_cast<TA>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < n*k; ++j) h_B[j] = static_cast<TB>( 2*(rand() / double(RAND_MAX)) - 1 );
  for (int j = 0; j < m*n; ++j) h_C[j] = static_cast<TC>(-1);

  thrust::device_vector<TA> d_A = h_A;
  thrust::device_vector<TB> d_B = h_B;
  thrust::device_vector<TC> d_C = h_C;

  double gflops = (2.0*m*n*k) * 1e-9;

  const int timing_iterations = 100;
  GPU_Clock timer;

  int ldA = 0, ldB = 0, ldC = m;

  if (transA == 'N') {
    ldA = m;
  } else if (transA == 'T') {
    ldA = k;
  } else {
    assert(false);
  }

  if (transB == 'N') {
    ldB = k;
  } else if (transB == 'T') {
    ldB = n;
  } else {
    assert(false);
  }

  // Run once
  d_C = h_C;
  gemm(transA, transB, m, n, k,
       alpha,
       d_A.data().get(), ldA,
       d_B.data().get(), ldB,
       beta,
       d_C.data().get(), ldC);
  CUTE_CHECK_LAST();
  thrust::host_vector<TC> cute_result = d_C;

  // Timing iterations
  timer.start();
  for (int i = 0; i < timing_iterations; ++i) {
    gemm(transA, transB, m, n, k,
         alpha,
         d_A.data().get(), ldA,
         d_B.data().get(), ldB,
         beta,
         d_C.data().get(), ldC);
  }
  double cute_time = timer.seconds() / timing_iterations;
  CUTE_CHECK_LAST();
  printf("CUTE_GEMM:   [%6.1f]GFlop/s  (%6.4f)ms\n", gflops / cute_time, cute_time*1000);

  return 0;
}
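
//
// A minimal sketch of how this example might be built and verified -- the file name,
// CUTLASS checkout path, and tolerance below are assumptions, not part of the example:
//
//   nvcc -std=c++17 -arch=sm_80 \
//        -I${CUTLASS_DIR}/include -I${CUTLASS_DIR}/tools/util/include \
//        sgemm_sm80.cu -o sgemm_sm80
//   ./sgemm_sm80 [M] [N] [K] [transA] [transB]
//
// The example reports GFlop/s but does not check numerical correctness. For the default
// NT path (m-major A with ldA == m, B stored as (N,K) with ldB == n, C with ldC == m),
// cute_result could be compared against a slow host reference along these lines:
//
//   for (int j = 0; j < n; ++j)
//     for (int i = 0; i < m; ++i) {
//       double acc = 0.0;
//       for (int p = 0; p < k; ++p) acc += double(h_A[i + p*m]) * double(h_B[j + p*n]);
//       double ref = alpha*acc + beta*double(h_C[i + j*m]);
//       // assert(std::abs(ref - double(cute_result[i + j*m])) < tolerance);
//     }
//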