 cc3c29a81a
			
		
	
	
		cc3c29a81a
		
			
		
	
	
	
	
		
			
			* v3.6 * update changelog * update readme * fix typo * fixing typos * hopper gemm with weight prefetch --------- Co-authored-by: yuzhai <yuzhai@nvidia.com> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
		
			
				
	
	
		
			513 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			513 lines
		
	
	
		
			22 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /***************************************************************************************************
 | |
|  * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 | |
|  * SPDX-License-Identifier: BSD-3-Clause
 | |
|  *
 | |
|  * Redistribution and use in source and binary forms, with or without
 | |
|  * modification, are permitted provided that the following conditions are met:
 | |
|  *
 | |
|  * 1. Redistributions of source code must retain the above copyright notice, this
 | |
|  * list of conditions and the following disclaimer.
 | |
|  *
 | |
|  * 2. Redistributions in binary form must reproduce the above copyright notice,
 | |
|  * this list of conditions and the following disclaimer in the documentation
 | |
|  * and/or other materials provided with the distribution.
 | |
|  *
 | |
|  * 3. Neither the name of the copyright holder nor the names of its
 | |
|  * contributors may be used to endorse or promote products derived from
 | |
|  * this software without specific prior written permission.
 | |
|  *
 | |
|  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 | |
|  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 | |
|  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 | |
|  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 | |
|  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 | |
|  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 | |
|  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 | |
|  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 | |
|  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 | |
|  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 | |
|  *
 | |
|  **************************************************************************************************/
 | |
| #pragma once
 | |
| 
 | |
| #include <cute/config.hpp>
 | |
| #include <cute/util/type_traits.hpp>
 | |
| 
 | |
| #include <cute/atom/mma_atom.hpp>
 | |
| 
 | |
| #include <cute/algorithm/axpby.hpp>
 | |
| #include <cute/algorithm/functional.hpp>
 | |
| #include <cute/algorithm/gemm.hpp>
 | |
| 
 | |
| #include <cute/tensor_impl.hpp>
 | |
| 
 | |
| namespace cute
 | |
| {
 | |
| 
 | |
| //
 | |
| // Cooperative Shared-Memory GEMMs
 | |
| //
 | |
| 
 | |
| namespace detail {
 | |
| 
 | |
| // Predicated Cooperative GEMM
 | |
| template <class... Args,
 | |
|           class Alpha, class TA, class ALayout, class TB, class BLayout,
 | |
|           class Beta,  class TC, class CLayout,
 | |
|           class ALoadTransformOp, class BLoadTransformOp,
 | |
|           class CLoadTransformOp, class CStoreTransformOp,
 | |
|           __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
 | |
|                           BLayout::rank == 2 && is_smem<TB>::value &&
 | |
|                           CLayout::rank == 2 && is_smem<TC>::value)>
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| cooperative_gemm_predication(ThrMMA<Args...> const& thr_mma,
 | |
|                              Alpha const& alpha,
 | |
|                              Tensor<TA, ALayout> sA,
 | |
|                              Tensor<TB, BLayout> sB,
 | |
|                              Beta  const& beta,
 | |
|                              Tensor<TC, CLayout> sC,
 | |
|                              ALoadTransformOp  const& sA_load_op,  // transforms A values before use in GEMM
 | |
|                              BLoadTransformOp  const& sB_load_op,  // transforms B values before use in GEMM
 | |
|                              CLoadTransformOp  const& sC_load_op,  // transforms C values before use in GEMM
 | |
|                              CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
 | |
| {
 | |
|   using TypeA = typename TA::value_type;
 | |
|   using TypeB = typename TB::value_type;
 | |
|   using TypeC = typename TC::value_type;
 | |
| 
 | |
|   //
 | |
|   // MMA Partitioning
 | |
|   //
 | |
| 
 | |
|   // Partition the sA, sB, and sC tiles across the threads for the MMA
 | |
|   Tensor tCsA = thr_mma.partition_A(sA);                            // (MMA,MMA_M,MMA_K)
 | |
|   Tensor tCsB = thr_mma.partition_B(sB);                            // (MMA,MMA_N,MMA_K)
 | |
|   Tensor tCsC = thr_mma.partition_C(sC);                            // (MMA,MMA_M,MMA_N)
 | |
| 
 | |
|   // Create register tensors for the MMA to operate on
 | |
|   Tensor tCrA = thr_mma.make_fragment_A(tCsA);                      // (MMA,MMA_M,MMA_K)
 | |
|   Tensor tCrB = thr_mma.make_fragment_B(tCsB);                      // (MMA,MMA_N,MMA_K)
 | |
|   Tensor tCrC = thr_mma.make_fragment_C(tCsC);                      // (MMA,MMA_M,MMA_N)
 | |
| 
 | |
| #if 0
 | |
|   if (thread0()) {
 | |
|     print("  sA: "); print(  sA); print("\n");
 | |
|     print("  sB: "); print(  sB); print("\n");
 | |
|     print("  sC: "); print(  sC); print("\n");
 | |
|     print(thr_mma);
 | |
|     print("tCsA: "); print(tCsA); print("\n");
 | |
|     print("tCsB: "); print(tCsB); print("\n");
 | |
|     print("tCsC: "); print(tCsC); print("\n");
 | |
|     print("tCrA: "); print(tCrA); print("\n");
 | |
|     print("tCrB: "); print(tCrB); print("\n");
 | |
|     print("tCrC: "); print(tCrC); print("\n");
 | |
|   }
 | |
| #endif
 | |
| 
 | |
|   //
 | |
|   // PREDICATION
 | |
|   //
 | |
| 
 | |
|   // Create coordinate tensors for the problem
 | |
|   Tensor cA = make_identity_tensor(shape(sA));                      // (M,K) -> (m,k)
 | |
|   Tensor cB = make_identity_tensor(shape(sB));                      // (N,K) -> (n,k)
 | |
| 
 | |
|   // Repeat partitioning with thr_mma
 | |
|   Tensor tCcA = thr_mma.partition_A(cA);                            // (MMA,MMA_M,MMA_K) -> (m,k)
 | |
|   Tensor tCcB = thr_mma.partition_B(cB);                            // (MMA,MMA_N,MMA_K) -> (n,k)
 | |
| 
 | |
|   // Allocate the preds for MMA- and MMA_MN-modes
 | |
|   Tensor tCpA = make_tensor<bool>(make_shape(size<0>(tCsA), size<1>(tCsA)));
 | |
|   Tensor tCpB = make_tensor<bool>(make_shape(size<0>(tCsB), size<1>(tCsB)));
 | |
| 
 | |
|   // Populate the predicates on M and N
 | |
|   CUTE_UNROLL
 | |
|   for (int i = 0; i < size(tCpA); ++i) {
 | |
|     tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA));
 | |
|   }
 | |
|   CUTE_UNROLL
 | |
|   for (int i = 0; i < size(tCpB); ++i) {
 | |
|     tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB));
 | |
|   }
 | |
| 
 | |
| #if 0
 | |
|   if (thread0()) {
 | |
|     print("  cA: "); print(  cA); print("\n");
 | |
|     print("  cB: "); print(  cB); print("\n");
 | |
|     print("tCcA: "); print(tCcA); print("\n");
 | |
|     print("tCcB: "); print(tCcB); print("\n");
 | |
|     print_tensor(tCpA);
 | |
|     print_tensor(tCpB);
 | |
|   }
 | |
| #endif
 | |
| 
 | |
|   //
 | |
|   // PREFETCH k_block = 0
 | |
|   //   Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
 | |
|   //   Assumes the MMA-tiling in K is trivial
 | |
|   //
 | |
| 
 | |
|   constexpr int K_BLOCK_MAX = size<2>(tCrA);
 | |
| 
 | |
|   CUTE_UNROLL
 | |
|   for (int m = 0; m < size<1>(tCrA); ++m) {     // Copy MMA_M
 | |
|     CUTE_UNROLL
 | |
|     for (int i = 0; i < size<0>(tCrA); ++i) {   // Copy MMA_I
 | |
|       tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{};
 | |
|     }
 | |
|   }
 | |
|   CUTE_UNROLL
 | |
|   for (int n = 0; n < size<1>(tCrB); ++n) {     // Copy MMA_N
 | |
|     CUTE_UNROLL
 | |
|     for (int i = 0; i < size<0>(tCrB); ++i) {   // Copy MMA_I
 | |
|       tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{};
 | |
|     }
 | |
|   }
 | |
|   //
 | |
|   // MAINLOOP
 | |
|   //
 | |
| 
 | |
|   // Clear accumulators
 | |
|   clear(tCrC);
 | |
| 
 | |
|   CUTE_UNROLL
 | |
|   for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
 | |
|   {
 | |
|     if (k_block < K_BLOCK_MAX-1)   // static-if not the last k_block
 | |
|     {
 | |
|       int k_next = k_block + 1;    // Load k_next block
 | |
| 
 | |
|       //   Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block
 | |
|       //   Assumes the MMA-tiling in K is trivial
 | |
| 
 | |
|       CUTE_UNROLL
 | |
|       for (int m = 0; m < size<1>(tCrA); ++m) {       // Copy MMA_M
 | |
|         CUTE_UNROLL
 | |
|         for (int i = 0; i < size<0>(tCrA); ++i) {     // Copy MMA_I
 | |
|           tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{};
 | |
|         }
 | |
|       }
 | |
|       CUTE_UNROLL
 | |
|       for (int n = 0; n < size<1>(tCrB); ++n) {       // Copy MMA_N
 | |
|         CUTE_UNROLL
 | |
|         for (int i = 0; i < size<0>(tCrB); ++i) {     // Copy MMA_I
 | |
|           tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{};
 | |
|         }
 | |
|       }
 | |
|     }
 | |
|     // GEMM on k_block in registers
 | |
|     gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Epilogue
 | |
|   //
 | |
| 
 | |
|   // Create coordinate tensors for the problem
 | |
|   Tensor cC   = make_identity_tensor(shape(sC));                     // (M,N) -> (m,n)
 | |
|   // Repeat partitioning with thr_mma
 | |
|   Tensor tCcC = thr_mma.partition_C(cC);                             // (MMA,MMA_M,MMA_N) -> (m,n)
 | |
| 
 | |
|   const bool isBetaZero = (beta == Beta{});
 | |
| 
 | |
|   // Custom axpby_if for now
 | |
|   CUTE_UNROLL
 | |
|   for (int i = 0; i < size(tCrC); ++i)
 | |
|   {
 | |
|     if (elem_less(tCcC(i), shape(sC)))
 | |
|     {
 | |
|       tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast<TypeC>(tCrC(i))
 | |
|                                        : alpha * static_cast<TypeC>(tCrC(i)) +
 | |
|                                           beta * static_cast<TypeC>(sC_load_op(tCsC(i))));
 | |
|     }
 | |
|   }
 | |
| }
 | |
| 
 | |
| // Slow fallback path
 | |
| template <class... Args,
 | |
|           class Alpha, class TA, class ALayout, class TB, class BLayout,
 | |
|           class Beta,  class TC, class CLayout,
 | |
|           class ALoadTransformOp, class BLoadTransformOp,
 | |
|           class CLoadTransformOp, class CStoreTransformOp,
 | |
|           __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
 | |
|                           BLayout::rank == 2 && is_smem<TB>::value &&
 | |
|                           CLayout::rank == 2 && is_smem<TC>::value)>
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| cooperative_gemm_predication(uint32_t thread_idx,
 | |
|                              TiledMMA<Args...> const& tiled_mma,
 | |
|                              Alpha const& alpha,
 | |
|                              Tensor<TA, ALayout> sA,
 | |
|                              Tensor<TB, BLayout> sB,
 | |
|                              Beta  const& beta,
 | |
|                              Tensor<TC, CLayout> sC,
 | |
|                              ALoadTransformOp  const& sA_load_op,  // transforms A values before use in GEMM
 | |
|                              BLoadTransformOp  const& sB_load_op,  // transforms B values before use in GEMM
 | |
|                              CLoadTransformOp  const& sC_load_op,  // transforms C values before use in GEMM
 | |
|                              CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
 | |
| {
 | |
|   // ThrMMA
 | |
|   auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
 | |
|   cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op);
 | |
| }
 | |
| 
 | |
| // Unpredicated Cooperative GEMM
 | |
| template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
 | |
|           class... Args,
 | |
|           class Alpha, class TA, class ALayout, class TB, class BLayout,
 | |
|           class Beta,  class TC, class CLayout,
 | |
|           class ALoadTransformOp, class BLoadTransformOp,
 | |
|           class CLoadTransformOp, class CStoreTransformOp,
 | |
|           __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
 | |
|                           BLayout::rank == 2 && is_smem<TB>::value &&
 | |
|                           CLayout::rank == 2 && is_smem<TC>::value)>
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| cooperative_gemm_no_predication(uint32_t thread_idx,
 | |
|                                 TiledMMA<Args...> const& tiled_mma,
 | |
|                                 Alpha const& alpha,
 | |
|                                 Tensor<TA, ALayout> sA,
 | |
|                                 Tensor<TB, BLayout> sB,
 | |
|                                 Beta  const& beta,
 | |
|                                 Tensor<TC, CLayout> sC,
 | |
|                                 ALoadTransformOp  const& sA_load_op,  // transforms A values before use in GEMM
 | |
|                                 BLoadTransformOp  const& sB_load_op,  // transforms B values before use in GEMM
 | |
|                                 CLoadTransformOp  const& sC_load_op,  // transforms C values before use in GEMM
 | |
|                                 CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C
 | |
| {
 | |
|   using TypeA = typename TA::value_type;
 | |
|   using TypeB = typename TB::value_type;
 | |
|   using TypeC = typename TC::value_type;
 | |
| 
 | |
|   // ThrMMA
 | |
|   auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
 | |
| 
 | |
|   //
 | |
|   // MMA Partitioning
 | |
|   //
 | |
| 
 | |
|   Tensor tCsC = thr_mma.partition_C(sC);
 | |
|   // Create register tensors for the MMA to operate on
 | |
|   Tensor tCrA  = thr_mma.partition_fragment_A(sA);                    // (MMA,MMA_M,MMA_K)
 | |
|   Tensor tCrB  = thr_mma.partition_fragment_B(sB);                    // (MMA,MMA_N,MMA_K)
 | |
|   Tensor tCrC  = thr_mma.make_fragment_C(tCsC);                       // (MMA,MMA_M,MMA_N)
 | |
| 
 | |
|   using CopyOpAType = SmemCopyOpA;
 | |
|   using CopyOpBType = SmemCopyOpB;
 | |
| 
 | |
|   auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<CopyOpAType, TypeA>{}, thr_mma);
 | |
|   auto smem_thr_copy_A   = smem_tiled_copy_A.get_thread_slice(thread_idx);
 | |
|   Tensor tCsA            = smem_thr_copy_A.partition_S(sA);
 | |
|   Tensor tCrA_copy_view  = smem_thr_copy_A.retile_D(tCrA);
 | |
|   CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));             // CPY_M
 | |
|   CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view));             // CPY_K
 | |
| 
 | |
|   auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<CopyOpBType, TypeB>{}, thr_mma);
 | |
|   auto smem_thr_copy_B   = smem_tiled_copy_B.get_thread_slice(thread_idx);
 | |
|   Tensor tCsB            = smem_thr_copy_B.partition_S(sB);
 | |
|   Tensor tCrB_copy_view  = smem_thr_copy_B.retile_D(tCrB);
 | |
|   CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // CPY_N
 | |
|   CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view));            // CPY_K
 | |
| 
 | |
| #if 0
 | |
|   if (thread0()) {
 | |
|     print("  sA: "); print(sA); print("\n");
 | |
|     print("  sB: "); print(sB); print("\n");
 | |
|     print("  sC: "); print(sC); print("\n");
 | |
|     print(thr_mma); print("\n");
 | |
|     print("tCsC: "); print(tCsC); print("\n");
 | |
|     print("tCrA: "); print(tCrA); print("\n");
 | |
|     print("tCrB: "); print(tCrB); print("\n");
 | |
|     print("tCrC: "); print(tCrC); print("\n");
 | |
|     print(smem_thr_copy_A); print("\n");
 | |
|     print("tCsA: "); print(tCsA); print("\n");
 | |
|     print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n");
 | |
|     print(smem_thr_copy_B); print("\n");
 | |
|     print("tCsB: "); print(tCsB); print("\n");
 | |
|     print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n");
 | |
|   }
 | |
| #endif
 | |
| 
 | |
|   //
 | |
|   // PREFETCH
 | |
|   //
 | |
| 
 | |
|   copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{}));
 | |
|   copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{}));
 | |
|   //
 | |
|   // MAINLOOP
 | |
|   //
 | |
| 
 | |
|   // Clear accumulators
 | |
|   clear(tCrC);
 | |
| 
 | |
|   constexpr int K_BLOCK_MAX = size<2>(tCrA);
 | |
| 
 | |
|   CUTE_UNROLL
 | |
|   for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block)
 | |
|   {
 | |
|     // static-if load the next k_block. No k-predication required on these loads.
 | |
|     if (k_block < K_BLOCK_MAX-1)
 | |
|     {
 | |
|       // Load the next k_block
 | |
|       int k_next = k_block + 1;       // statically unrolled
 | |
|       copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next));
 | |
|       copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next));
 | |
|     }
 | |
| 
 | |
|     // Transform A and B, relying on the compiler to remove in case of identity ops
 | |
|     cute::transform(tCrA(_,_,k_block), sA_load_op);
 | |
|     cute::transform(tCrB(_,_,k_block), sB_load_op);
 | |
| 
 | |
|     // GEMM on k_block in registers
 | |
|     gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
 | |
|   }
 | |
| 
 | |
|   //
 | |
|   // Epilogue
 | |
|   //
 | |
| 
 | |
|   auto isBetaZero = [&] () {
 | |
|     if constexpr (is_complex<Beta>::value) {
 | |
|       return beta.real() == Int<0>{} && beta.imag() == Int<0>{};
 | |
|     }
 | |
|     else {
 | |
|       return beta == Int<0>{};
 | |
|     }
 | |
|     CUTE_GCC_UNREACHABLE;
 | |
|   } ();
 | |
| 
 | |
|   using CopyOpCType = SmemCopyOpC;
 | |
|   Tensor tCrD = thr_mma.make_fragment_C(tCsC);
 | |
|   if(!isBetaZero) {
 | |
|     copy(CopyOpCType{}, tCsC, tCrD);
 | |
|     // Transform C on/after load
 | |
|     cute::transform(tCrD, sC_load_op);
 | |
|   }
 | |
|   // C = alpha * (A * B) + beta * C
 | |
|   axpby(alpha, tCrC, beta, tCrD);
 | |
|   // Transform C before/on store
 | |
|   cute::transform(tCrD, sC_store_op);
 | |
|   copy(CopyOpCType{}, tCrD, tCsC);
 | |
| }
 | |
| 
 | |
| } // end namespace detail
 | |
| 
 | |
| template <class SmemCopyOpA, class SmemCopyOpB, class SmemCopyOpC,
 | |
|           class... Args,
 | |
|           class Alpha, class TA, class ALayout, class TB, class BLayout,
 | |
|           class Beta,  class TC, class CLayout,
 | |
|           class ALoadTransformOp = cute::identity, class BLoadTransformOp  = cute::identity,
 | |
|           class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
 | |
|           __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
 | |
|                           BLayout::rank == 2 && is_smem<TB>::value &&
 | |
|                           CLayout::rank == 2 && is_smem<TC>::value)>
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| cooperative_gemm(uint32_t thread_idx,
 | |
|                  TiledMMA<Args...> const& tiled_mma,
 | |
|                  Alpha const& alpha,
 | |
|                  Tensor<TA, ALayout> sA,
 | |
|                  Tensor<TB, BLayout> sB,
 | |
|                  Beta  const& beta,
 | |
|                  Tensor<TC, CLayout> sC,
 | |
|                  ALoadTransformOp  const& sA_load_op  = {}, // transforms A values before use in GEMM
 | |
|                  BLoadTransformOp  const& sB_load_op  = {}, // transforms B values before use in GEMM
 | |
|                  CLoadTransformOp  const& sC_load_op  = {}, // transforms C values before use in GEMM
 | |
|                  CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
 | |
| {
 | |
|   CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC));  // AM == CM
 | |
|   CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC));  // BN == CN
 | |
|   CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB));  // AK == BK
 | |
| 
 | |
|   using TypeA = typename TA::value_type;
 | |
|   using TypeB = typename TB::value_type;
 | |
|   using TypeC = typename TC::value_type;
 | |
| 
 | |
|   static_assert(is_convertible_v<decay_t<invoke_result_t<ALoadTransformOp, TypeA>>, TypeA>,
 | |
|     "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type");
 | |
|   static_assert(is_convertible_v<decay_t<invoke_result_t<BLoadTransformOp, TypeB>>, TypeB>,
 | |
|     "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type");
 | |
|   static_assert(is_convertible_v<decay_t<invoke_result_t<CLoadTransformOp, TypeC>>, TypeC>,
 | |
|     "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
 | |
|   static_assert(is_convertible_v<decay_t<invoke_result_t<CStoreTransformOp, TypeC>>, TypeC>,
 | |
|     "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type");
 | |
| 
 | |
|   static constexpr bool compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)),
 | |
|                                                 tile_shape(TiledMMA<Args...>{}));
 | |
|   if constexpr (compat) {
 | |
|     detail::cooperative_gemm_no_predication<SmemCopyOpA, SmemCopyOpB, SmemCopyOpC>(
 | |
|         thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
 | |
|         sA_load_op, sB_load_op, sC_load_op, sC_store_op
 | |
|     );
 | |
|   } else {
 | |
|     detail::cooperative_gemm_predication(
 | |
|       thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
 | |
|       sA_load_op, sB_load_op, sC_load_op, sC_store_op
 | |
|     );
 | |
|   }
 | |
| }
 | |
| 
 | |
| template <class... Args,
 | |
|           class Alpha, class TA, class ALayout, class TB, class BLayout,
 | |
|           class Beta,  class TC, class CLayout,
 | |
|           class ALoadTransformOp = cute::identity, class BLoadTransformOp  = cute::identity,
 | |
|           class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
 | |
|           __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
 | |
|                           BLayout::rank == 2 && is_smem<TB>::value &&
 | |
|                           CLayout::rank == 2 && is_smem<TC>::value)>
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| cooperative_gemm(uint32_t thread_idx,
 | |
|                  TiledMMA<Args...> const& tiled_mma,
 | |
|                  Alpha const& alpha,
 | |
|                  Tensor<TA, ALayout> sA,
 | |
|                  Tensor<TB, BLayout> sB,
 | |
|                  Beta  const& beta,
 | |
|                  Tensor<TC, CLayout> sC,
 | |
|                  ALoadTransformOp  const& sA_load_op  = {}, // transforms A values before use in GEMM
 | |
|                  BLoadTransformOp  const& sB_load_op  = {}, // transforms B values before use in GEMM
 | |
|                  CLoadTransformOp  const& sC_load_op  = {}, // transforms C values before use in GEMM
 | |
|                  CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
 | |
| {
 | |
|   using CopyOpA = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TA::value_type>>;
 | |
|   using CopyOpB = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TB::value_type>>;
 | |
|   using CopyOpC = AutoVectorizingCopyWithAssumedAlignment<sizeof_bits_v<typename TC::value_type>>;
 | |
|   cooperative_gemm<CopyOpA, CopyOpB, CopyOpC>(
 | |
|       thread_idx, tiled_mma, alpha, sA, sB, beta, sC,
 | |
|       sA_load_op, sB_load_op, sC_load_op, sC_store_op
 | |
|   );
 | |
| }
 | |
| 
 | |
| // Legacy overload of cute::gemm for backwards-compatibility
 | |
| template <class... Args,
 | |
|           class Alpha, class TA, class ALayout, class TB, class BLayout,
 | |
|           class Beta,  class TC, class CLayout,
 | |
|           class ALoadTransformOp = cute::identity, class BLoadTransformOp  = cute::identity,
 | |
|           class CLoadTransformOp = cute::identity, class CStoreTransformOp = cute::identity,
 | |
|           __CUTE_REQUIRES(ALayout::rank == 2 && is_smem<TA>::value &&
 | |
|                           BLayout::rank == 2 && is_smem<TB>::value &&
 | |
|                           CLayout::rank == 2 && is_smem<TC>::value)>
 | |
| CUTE_HOST_DEVICE
 | |
| void
 | |
| gemm(ThrMMA<Args...> const& thr_mma,
 | |
|      Alpha const& alpha,
 | |
|      Tensor<TA, ALayout> sA,
 | |
|      Tensor<TB, BLayout> sB,
 | |
|      Beta  const& beta,
 | |
|      Tensor<TC, CLayout> sC,
 | |
|      ALoadTransformOp  const& sA_load_op  = {}, // transforms A values before use in GEMM
 | |
|      BLoadTransformOp  const& sB_load_op  = {}, // transforms B values before use in GEMM
 | |
|      CLoadTransformOp  const& sC_load_op  = {}, // transforms C values before use in GEMM
 | |
|      CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C
 | |
| {
 | |
|   // Goes directly to the slow path to avoid getting thread_idx from thr_mma
 | |
|   detail::cooperative_gemm_predication(
 | |
|     thr_mma, alpha, sA, sB, beta, sC,
 | |
|     sA_load_op, sB_load_op, sC_load_op, sC_store_op
 | |
|   );
 | |
| }
 | |
| 
 | |
| } // end namespace cute
 |