/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/tensor_impl.hpp>

#include <cute/atom/mma_atom.hpp>
#include <cute/atom/copy_atom.hpp>

#include <cute/algorithm/axpby.hpp>
#include <cute/algorithm/functional.hpp>
#include <cute/algorithm/gemm.hpp>

namespace cute
{

//
// Cooperative Shared-Memory GEMMs
//

namespace detail {

// Predicated Cooperative GEMM
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp, class BLoadTransformOp,
          class CLoadTransformOp, class CStoreTransformOp,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
                             Alpha const& alpha,
                             Tensor<TA, ALayout> sA,
                             Tensor<TB, BLayout> sB,
                             Beta  const& beta,
                             Tensor<TC, CLayout> sC,
                             ALoadTransformOp  const& sA_load_op,  // transforms A values before use in GEMM
                             BLoadTransformOp  const& sB_load_op,  // transforms B values before use in GEMM
                             CLoadTransformOp  const& sC_load_op,  // transforms C values before use in GEMM
                             CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
  using TypeA = typename TA::value_type;
  using TypeB = typename TB::value_type;
  using TypeC = typename TC::value_type;

  //
  // MMA Partitioning
  //

  // Partition the sA, sB, and sC tiles across the threads for the MMA
  Tensor tCsA = thr_mma.partition_A(sA);                    // (MMA,MMA_M,MMA_K)
  Tensor tCsB = thr_mma.partition_B(sB);                    // (MMA,MMA_N,MMA_K)
  Tensor tCsC = thr_mma.partition_C(sC);                    // (MMA,MMA_M,MMA_N)
  // Create register tensors for the MMA to operate on
  Tensor tCrA = thr_mma.make_fragment_A(tCsA);              // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.make_fragment_B(tCsB);              // (MMA,MMA_N,MMA_K)
  Tensor tCrC = thr_mma.make_fragment_C(tCsC);              // (MMA,MMA_M,MMA_N)

#if 0
  if (thread0()) {
    print("  sA: "); print(  sA); print("\n");
    print("  sB: "); print(  sB); print("\n");
    print("  sC: "); print(  sC); print("\n");
    print(thr_mma);
    print("tCsA: "); print(tCsA); print("\n");
    print("tCsB: "); print(tCsB); print("\n");
    print("tCsC: "); print(tCsC); print("\n");
    print("tCrA: "); print(tCrA); print("\n");
    print("tCrB: ");
    print(tCrB); print("\n");
    print("tCrC: "); print(tCrC); print("\n");
  }
#endif

  //
  // PREDICATION
  //

  // Create coordinate tensors for the problem
  Tensor cA = make_identity_tensor(shape(sA));              // (M,K) -> (m,k)
  Tensor cB = make_identity_tensor(shape(sB));              // (N,K) -> (n,k)
  // Repeat partitioning with thr_mma
  Tensor tCcA = thr_mma.partition_A(cA);                    // (MMA,MMA_M,MMA_K) -> (m,k)
  Tensor tCcB = thr_mma.partition_B(cB);                    // (MMA,MMA_N,MMA_K) -> (n,k)

  // Allocate the preds for MMA- and MMA_MN-modes
  Tensor tCpA = make_tensor<bool>(make_shape(size<0>(tCsA), size<1>(tCsA)));
  Tensor tCpB = make_tensor<bool>(make_shape(size<0>(tCsB), size<1>(tCsB)));

  // Populate the predicates on M and N
  CUTE_UNROLL
  for (int i = 0; i < size(tCpA); ++i) {
    tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA));
  }
  CUTE_UNROLL
  for (int i = 0; i < size(tCpB); ++i) {
    tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB));
  }

#if 0
  if (thread0()) {
    print("  cA: "); print(  cA); print("\n");
    print("  cB: "); print(  cB); print("\n");
    print("tCcA: "); print(tCcA); print("\n");
    print("tCcB: "); print(tCcB); print("\n");
    print_tensor(tCpA);
    print_tensor(tCpB);
  }
#endif

  //
  // PREFETCH k_block = 0
  //   Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
  //   Assumes the MMA-tiling in K is trivial
  //

  constexpr int K_BLOCK_MAX = size<2>(tCrA);

  CUTE_UNROLL
  for (int m = 0; m < size<1>(tCrA); ++m) {           // Copy MMA_M
    CUTE_UNROLL
    for (int i = 0; i < size<0>(tCrA); ++i) {         // Copy MMA_I
      tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
    }
  }
  CUTE_UNROLL
  for (int n = 0; n < size<1>(tCrB); ++n) {           // Copy MMA_N
    CUTE_UNROLL
    for (int i = 0; i < size<0>(tCrB); ++i) {         // Copy MMA_I
      tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
    }
  }

  //
  // MAINLOOP
  //

  // Clear accumulators
  clear(tCrC);

  CUTE_UNROLL
  for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
  {
    if (k_block < K_BLOCK_MAX-1)      // static-if not the last k_block
    {
      int k_next = k_block + 1;       // Load k_next block

      // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
      // Assumes the MMA-tiling in K is trivial

      CUTE_UNROLL
      for (int m = 0; m < size<1>(tCrA); ++m) {       // Copy MMA_M
        CUTE_UNROLL
        for (int i = 0; i < size<0>(tCrA); ++i) {     // Copy MMA_I
          tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
        }
      }
      CUTE_UNROLL
      for (int n = 0; n < size<1>(tCrB); ++n) {       // Copy MMA_N
        CUTE_UNROLL
        for (int i = 0; i < size<0>(tCrB); ++i) {     // Copy MMA_I
          tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
        }
      }
    }

    // GEMM on k_block in registers
    gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
  }

  //
  // Epilogue
  //

  // Create coordinate tensors for the problem
  Tensor cC = make_identity_tensor(shape(sC));              // (M,N) -> (m,n)
  // Repeat partitioning with thr_mma
  Tensor tCcC = thr_mma.partition_C(cC);                    // (MMA,MMA_M,MMA_N) -> (m,n)

  const bool isBetaZero = (beta == Beta{});

  // Custom axpby_if for now
  CUTE_UNROLL
  for (int i = 0; i < size(tCrC); ++i)
  {
    if (elem_less(tCcC(i), shape(sC)))
    {
      tCsC(i) = sC_store_op(isBetaZero ?
                              alpha * static_cast<TypeC>(tCrC(i)) :
                              alpha * static_cast<TypeC>(tCrC(i)) +
                               beta * static_cast<TypeC>(sC_load_op(tCsC(i))));
    }
  }
}

// Slow fallback path
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp, class BLoadTransformOp,
          class CLoadTransformOp, class CStoreTransformOp,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_predication(uint32_t thread_idx,
                             TiledMMA<Args...> const& tiled_mma,
                             Alpha const& alpha,
                             Tensor<TA, ALayout> sA,
                             Tensor<TB, BLayout> sB,
                             Beta  const& beta,
                             Tensor<TC, CLayout> sC,
                             ALoadTransformOp  const& sA_load_op,  // transforms A values before use in GEMM
                             BLoadTransformOp  const& sB_load_op,  // transforms B values before use in GEMM
                             CLoadTransformOp  const& sC_load_op,  // transforms C values before use in GEMM
                             CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
  // ThrMMA
  auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
  cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op);
}

// Unpredicated Cooperative GEMM
template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
          class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp, class BLoadTransformOp,
          class CLoadTransformOp, class CStoreTransformOp,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm_no_predication(uint32_t thread_idx,
                                TiledMMA<Args...> const& tiled_mma,
                                Alpha const& alpha,
                                Tensor<TA, ALayout> sA,
                                Tensor<TB, BLayout> sB,
                                Beta  const& beta,
                                Tensor<TC, CLayout> sC,
                                ALoadTransformOp  const& sA_load_op,  // transforms A values before use in GEMM
                                BLoadTransformOp  const& sB_load_op,  // transforms B values before use in GEMM
                                CLoadTransformOp  const& sC_load_op,  // transforms C values before use in GEMM
                                CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
{
  using TypeA = typename TA::value_type;
  using TypeB = typename TB::value_type;
  using TypeC = typename TC::value_type;

  // ThrMMA
  auto thr_mma = tiled_mma.get_thread_slice(thread_idx);

  //
  // MMA Partitioning
  //

  Tensor tCsC = thr_mma.partition_C(sC);
  // Create register tensors for the MMA to operate on
  Tensor tCrA = thr_mma.partition_fragment_A(sA);           // (MMA,MMA_M,MMA_K)
  Tensor tCrB = thr_mma.partition_fragment_B(sB);           // (MMA,MMA_N,MMA_K)
  Tensor tCrC = thr_mma.make_fragment_C(tCsC);              // (MMA,MMA_M,MMA_N)

  using CopyOpAType = SmemCopyOpA;
  using CopyOpBType = SmemCopyOpB;

  auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, TypeA>{}, thr_mma);
  auto smem_thr_copy_A   = smem_tiled_copy_A.get_thread_slice(thread_idx);
  Tensor tCsA            = smem_thr_copy_A.partition_S(sA);
  Tensor tCrA_copy_view  = smem_thr_copy_A.retile_D(tCrA);
  CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));  // CPY_M
  CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view));  // CPY_K

  auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, TypeB>{}, thr_mma);
  auto smem_thr_copy_B   = smem_tiled_copy_B.get_thread_slice(thread_idx);
  Tensor tCsB            = smem_thr_copy_B.partition_S(sB);
  Tensor tCrB_copy_view  = smem_thr_copy_B.retile_D(tCrB);
  CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));  // CPY_N
  CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view));  // CPY_K

#if 0
  if (thread0()) {
    print("  sA: "); print(sA); print("\n");
    print("  sB: "); print(sB); print("\n");
    print("  sC: "); print(sC); print("\n");
    print(thr_mma); print("\n");
    print("tCsC: "); print(tCsC); print("\n");
    print("tCrA: "); print(tCrA); print("\n");
    print("tCrB: "); print(tCrB); print("\n");
    print("tCrC: "); print(tCrC); print("\n");
    print(smem_thr_copy_A); print("\n");
    print("tCsA: "); print(tCsA); print("\n");
    print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n");
    print(smem_thr_copy_B); print("\n");
    print("tCsB: "); print(tCsB); print("\n");
    print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n");
  }
#endif

  //
  // PREFETCH
  //
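  // Two-stage register pipeline: the k_block = 0 fragments of A and B are copied from shared
  // memory into registers here, and each mainloop iteration below copies k_block+1 while the
  // MMA consumes k_block, overlapping shared-memory loads with math. No k-predication is
  // needed on this path because it is only selected when the problem shape evenly divides
  // the tiled MMA's tile shape.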
  copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
  copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));

  //
  // MAINLOOP
  //

  // Clear accumulators
  clear(tCrC);

  constexpr int K_BLOCK_MAX = size<2>(tCrA);

  CUTE_UNROLL
  for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
  {
    // static-if load the next k_block. No k-predication required on these loads.
    if (k_block < K_BLOCK_MAX-1)
    {
      // Load the next k_block
      int k_next = k_block + 1;       // statically unrolled
      copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next));
      copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next));
    }

    // Transform A and B, relying on the compiler to remove in case of identity ops
    cute::transform(tCrA(_,_,k_block), sA_load_op);
    cute::transform(tCrB(_,_,k_block), sB_load_op);

    // GEMM on k_block in registers
    gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
  }

  //
  // Epilogue
  //

  auto isBetaZero = [&] () {
    if constexpr (is_complex<Beta>::value) {
      return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
    }
    else {
      return beta == Int<0>{};
    }
    CUTE_GCC_UNREACHABLE;
  } ();

  using CopyOpCType = SmemCopyOpC;

  Tensor tCrD = thr_mma.make_fragment_C(tCsC);
  if (!isBetaZero) {
    copy(CopyOpCType{}, tCsC, tCrD);
    // Transform C on/after load
    cute::transform(tCrD, sC_load_op);
  }

  // C = alpha * (A * B) + beta * C
  axpby(alpha, tCrC, beta, tCrD);

  // Transform C before/on store
  cute::transform(tCrD, sC_store_op);
  copy(CopyOpCType{}, tCrD, tCsC);
}

} // end namespace detail

template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
          class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp  = cute::identity, class BLoadTransformOp  = cute::identity,
          class CLoadTransformOp  = cute::identity, class CStoreTransformOp = cute::identity,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
                 TiledMMA<Args...> const& tiled_mma,
                 Alpha const& alpha,
                 Tensor<TA, ALayout> sA,
                 Tensor<TB, BLayout> sB,
                 Beta  const& beta,
                 Tensor<TC, CLayout> sC,
                 ALoadTransformOp  const& sA_load_op  = {}, // transforms A values before use in GEMM
                 BLoadTransformOp  const& sB_load_op  = {}, // transforms B values before use in GEMM
                 CLoadTransformOp  const& sC_load_op  = {}, // transforms C values before use in GEMM
                 CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
  CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC));   // AM == CM
  CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC));   // BN == CN
  CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));   // AK == BK

  using TypeA = typename TA::value_type;
  using TypeB = typename TB::value_type;
  using TypeC = typename TC::value_type;

  static_assert(is_convertible_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
      "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type");
  static_assert(is_convertible_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
      "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type");
  static_assert(is_convertible_v<decay_t<invoke_result_t<CLoadTransformOp, TypeC>>, TypeC>,
      "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
  static_assert(is_convertible_v<decay_t<invoke_result_t<CStoreTransformOp, TypeC>>, TypeC>,
      "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");

  static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
                                                tile_shape(TiledMMA<Args...>{}));
  if constexpr (compat) {
    detail::cooperative_gemm_no_predication<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
      thread_idx, tiled_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op
    );
  } else {
    detail::cooperative_gemm_predication(
      thread_idx, tiled_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op
    );
  }
}

template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp  = cute::identity, class BLoadTransformOp  = cute::identity,
          class CLoadTransformOp  = cute::identity, class CStoreTransformOp = cute::identity,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value
                          && CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
cooperative_gemm(uint32_t thread_idx,
                 TiledMMA<Args...> const& tiled_mma,
                 Alpha const& alpha,
                 Tensor<TA, ALayout> sA,
                 Tensor<TB, BLayout> sB,
                 Beta  const& beta,
                 Tensor<TC, CLayout> sC,
                 ALoadTransformOp  const& sA_load_op  = {}, // transforms A values before use in GEMM
                 BLoadTransformOp  const& sB_load_op  = {}, // transforms B values before use in GEMM
                 CLoadTransformOp  const& sC_load_op  = {}, // transforms C values before use in GEMM
                 CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
  using CopyOpA = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TA::value_type>>;
  using CopyOpB = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TB::value_type>>;
  using CopyOpC = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TC::value_type>>;
  cooperative_gemm<CopyOpA, CopyOpB, CopyOpC>(
    thread_idx, tiled_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op
  );
}

// Legacy overload of cute::gemm for backwards-compatibility
template <class... Args,
          class Alpha, class TA, class ALayout, class TB, class BLayout,
          class Beta,  class TC, class CLayout,
          class ALoadTransformOp  = cute::identity, class BLoadTransformOp  = cute::identity,
          class CLoadTransformOp  = cute::identity, class CStoreTransformOp = cute::identity,
          __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
                          BLayout::rank == 2 && is_smem<TB>::value &&
                          CLayout::rank == 2 && is_smem<TC>::value)>
CUTE_HOST_DEVICE
void
gemm(ThrMMA<Args...> const& thr_mma,
     Alpha const& alpha,
     Tensor<TA, ALayout> sA,
     Tensor<TB, BLayout> sB,
     Beta  const& beta,
     Tensor<TC, CLayout> sC,
     ALoadTransformOp  const& sA_load_op  = {}, // transforms A values before use in GEMM
     BLoadTransformOp  const& sB_load_op  = {}, // transforms B values before use in GEMM
     CLoadTransformOp  const& sC_load_op  = {}, // transforms C values before use in GEMM
     CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
{
  // Goes directly to the slow path to avoid getting thread_idx from thr_mma
  detail::cooperative_gemm_predication(
    thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op
  );
}

} // end namespace cute
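////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative only, not part of the library): one way a kernel might call
// cute::cooperative_gemm on shared-memory tiles. The kernel name, element types, and the
// 64x64x16 tile sizes below are assumptions for this example; any TiledMMA with half_t A/B
// operands and float accumulators can be passed, with the unpredicated fast path taken when
// its tile shape evenly divides (64,64,16) and the predicated path taken otherwise.
//
//   #include <cute/tensor.hpp>
//   #include <cute/algorithm/cooperative_gemm.hpp>
//
//   template <class TiledMMAType>
//   __global__ void
//   example_block_gemm(TiledMMAType mma)
//   {
//     using namespace cute;
//
//     // Shared-memory tiles: A is (M,K), B is (N,K), C is (M,N), column-major by default
//     __shared__ half_t smem_A[64 * 16];
//     __shared__ half_t smem_B[64 * 16];
//     __shared__ float  smem_C[64 * 64];
//     Tensor sA = make_tensor(make_smem_ptr(smem_A), make_shape(Int<64>{}, Int<16>{}));
//     Tensor sB = make_tensor(make_smem_ptr(smem_B), make_shape(Int<64>{}, Int<16>{}));
//     Tensor sC = make_tensor(make_smem_ptr(smem_C), make_shape(Int<64>{}, Int<64>{}));
//
//     // ... cooperatively fill sA, sB, and sC, then __syncthreads() ...
//
//     // All threads of the tiled MMA participate:
//     //   C(m,n) = 1.0f * sum_k A(m,k) * B(n,k) + 0.0f * C(m,n)
//     cooperative_gemm(threadIdx.x, mma, 1.0f, sA, sB, 0.0f, sC);
//     __syncthreads();
//   }
//
////////////////////////////////////////////////////////////////////////////////////////////////////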