diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 5adcb1cb..c5f148d4 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -254,17 +254,21 @@ template < class ElementC_, class GmemLayoutTagC_, int AlignmentC, - class ElementD, + class ElementD_, class GmemLayoutTagD, int AlignmentD, class FusionOpOrCallbacks, class DispatchPolicy > struct Sm90TmaBuilderImpl { + // Passing void D disables destination store + smem allocation + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + // Passing void C disables source load + smem allocation using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; - + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; @@ -292,7 +296,7 @@ struct Sm90TmaBuilderImpl { EpilogueTile_MN, ElementC_, // Need to pass void through to expose via GemmUniversal GmemStrideTypeC, - ElementD, + ElementD_, GmemStrideTypeD, FusionCallbacks, CopyOpG2S, @@ -474,7 +478,7 @@ template < class ElementC, class GmemLayoutTagC, int AlignmentC, - class ElementD, + class ElementD_, class GmemLayoutTagD, int AlignmentD, class Schedule, @@ -491,7 +495,7 @@ struct CollectiveBuilder< ElementC, GmemLayoutTagC, AlignmentC, - ElementD, + ElementD_, GmemLayoutTagD, AlignmentD, Schedule, @@ -499,6 +503,8 @@ struct CollectiveBuilder< cute::enable_if_t || cute::is_same_v >> { private: + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override()); using DispatchPolicy = @@ -514,7 +520,7 @@ public: ElementC, GmemLayoutTagC, AlignmentC, - ElementD, + ElementD_, GmemLayoutTagD, AlignmentD, FusionOperation, diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index 301f85e6..1d904b05 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -121,11 +121,14 @@ public: static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); private: - using SmemElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using SmemElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using SmemElementC = cute::conditional_t; // prevents void ref breakages constexpr static int StagesC = StagesC_; constexpr static int StagesD = StagesD_; - constexpr static bool ReuseSmemC = ReuseSmemC_; - constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; constexpr static bool is_m_major_C = detail::is_m_major(); constexpr static bool is_m_major_D = detail::is_m_major(); @@ -139,23 +142,33 @@ private: make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), cute::conditional_t, Step<_1,_2,_3>>{} )); - constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t, + EmptyType>; + using SmemDStorage = cute::conditional_t, + EmptyType>; - struct TensorStorageWithC { - alignas(SmemAlignmentC) array_aligned smem_C; - alignas(SmemAlignmentD) array_aligned smem_D; + struct TensorStorageImpl: cute::tuple { + using Base = cute::tuple; - using FusionStorage = typename FusionCallbacks::SharedStorage; - FusionStorage thread; - }; + constexpr decltype(auto) + smem_C() { + return cute::get<0>(static_cast(*this)); + } - struct TensorStorageWithoutC { - alignas(SmemAlignmentD) array_aligned smem_D; + constexpr decltype(auto) + smem_D() { + return cute::get<1>(static_cast(*this)); + } using FusionStorage = typename FusionCallbacks::SharedStorage; FusionStorage thread; @@ -175,8 +188,7 @@ public: using StorePipelineState = cutlass::PipelineState; struct SharedStorage { - using TensorStorage = - cute::conditional_t; + using TensorStorage = TensorStorageImpl; TensorStorage tensors; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -203,7 +215,7 @@ public: SmemLayoutC{}(_,_,0))); using TMA_D = decltype(make_tma_copy( CopyOpS2G{}, - make_tensor(make_gmem_ptr(static_cast(nullptr)), + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideD{}, int32_t(0)), StrideD{}), SmemLayoutD{}(_,_,0))); @@ -233,16 +245,16 @@ public: ; typename Params::TMA_C tma_load_c; - if constexpr (not cute::is_void_v) { + if constexpr (is_source_supported) { Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC)); tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0)); } - - Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD)); - typename Params::TMA_D tma_store_d = make_tma_copy( - CopyOpS2G{}, - tensor_d, - SmemLayoutD{}(_,_,0)); + + typename Params::TMA_D tma_store_d; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD)); + tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0)); + } return { FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), @@ -272,8 +284,11 @@ public: auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; - bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + } if constexpr (not cute::is_void_v) { constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits::value; @@ -309,8 +324,12 @@ public: CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& epilogue_params) { - cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); - cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + if constexpr (is_source_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + } + if constexpr (is_destination_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } } CUTLASS_HOST_DEVICE @@ -365,9 +384,14 @@ public: Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) // Apply epilogue subtile, get matching smem tensor - SmemElementC* ptr_sC = reinterpret_cast(shared_tensors.smem_D.data()); - if constexpr (not ReuseSmemC and is_source_supported) { - ptr_sC = shared_tensors.smem_C.data(); + SmemElementC* ptr_sC = nullptr; + + if constexpr (is_source_supported) { + if constexpr (ReuseSmemC) { + ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); + } else { + ptr_sC = shared_tensors.smem_C().data(); + } } Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) @@ -499,11 +523,20 @@ public: Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) // Construct the corresponding pipelined smem tensors - SmemElementC* ptr_sC = reinterpret_cast(shared_tensors.smem_D.data()); - if constexpr (not ReuseSmemC and is_source_supported) { - ptr_sC = shared_tensors.smem_C.data(); + SmemElementC* ptr_sC = nullptr; + if constexpr (is_source_supported) { + if constexpr (ReuseSmemC) { + ptr_sC = reinterpret_cast(shared_tensors.smem_D().data()); + } else { + ptr_sC = shared_tensors.smem_C().data(); + } } - ElementD* ptr_sD = shared_tensors.smem_D.data(); + + SmemElementD* ptr_sD = nullptr; + if constexpr (is_destination_supported) { + ptr_sD = shared_tensors.smem_D().data(); + } + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) Tensor sD_epi = cute::as_position_independent_swizzle_tensor( @@ -514,19 +547,19 @@ public: TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); // (t)hread-partition for (r)egister to (s)mem copy (tRS_) - TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) // Allocate D registers Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); - Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) // Vectorized fragment view constexpr int FragmentSize = DispatchPolicy::FragmentSize; Tensor tRS_rAcc_frg = recast>(tRS_rAcc); - Tensor tRS_rD_frg = recast>(tRS_rD); + Tensor tRS_rD_frg = recast>(tRS_rD); CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); // (t)hread-partition for (s)mem to (r)egister copy (tSR_) @@ -653,7 +686,9 @@ public: } // Copy tile from register to smem - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } // Post visit, pre async fence callback entry point constexpr bool issue_smem_store = true; // No smem store predication @@ -662,8 +697,10 @@ public: // Write the tile from smem to gmem with TMA cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } } // Post async fence, pre TMA commit callback entry point diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 27aed592..84ca9e35 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -1247,6 +1247,33 @@ struct FusionCallbacks< }; ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { +template > +struct get_element_aux { + using type = void; +}; + +template +struct get_element_aux> { + using type = typename FusionOpOrCallbacks::ElementAux; +}; + +template +struct get_element_aux, cute::void_t<>> { + using type = typename get_element_aux::type; +}; + +template +struct get_element_aux, cute::void_t::Operation>> { + private: + using Operation = typename FusionCallbacks::Operation; + public: + using type = typename get_element_aux::type; +}; +} + +template +using get_element_aux_t = typename detail::get_element_aux::type; } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 8d856e5c..dbaaaa73 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -68,6 +68,7 @@ template < bool EnableNullptr = true // Noop on nullptr params > struct Sm90AuxStore { + using ElementAux = Element; static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 9eae6e19..e0f5d449 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -327,6 +327,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu ) cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu new file mode 100644 index 00000000..2b542f2b --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu @@ -0,0 +1,688 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 with cooperative EVT epilogue + D = alpha * acc + beta * c + aux_load +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#define CUTLASS_ARCH_MMA_SM90_SUPPORTED + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +namespace test::gemm::device { +template +static constexpr auto select_evt_d() { + using namespace cutlass::epilogue::fusion; + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using BinaryCompute0 = Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + >; + if constexpr (IsCNeed) { + using EVT_D = Sm90EVT, + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + BinaryCompute0>; + return *(EVT_D *)(nullptr); + } else { + return *(BinaryCompute0 *)(nullptr); + } +} + +template +bool testEVTAuxStoreWithoutD() { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (std::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + using GemmKernel = typename Gemm::GemmKernel; + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementD = typename Gemm::ElementD; + constexpr bool has_c = not cute::is_void_v; + cutlass::DeviceAllocation A_block; + cutlass::DeviceAllocation B_block; + cutlass::DeviceAllocation> C_block; + cutlass::DeviceAllocation D_block; + cutlass::DeviceAllocation aux_store_D_block; + cutlass::DeviceAllocation workspace; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + int l = 1; + problem_size = ProblemShapeType{m, n, k, l}; + + // Run Base Gemm to get reference D + A_block.reset(m * k); + B_block.reset(k * n); + C_block.reset(m * n); + D_block.reset(m * n); + aux_store_D_block.reset(m * n); + Gemm gemm_op_base; + + auto stride_A = cutlass::make_cute_packed_stride( + typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{})); + auto stride_B = cutlass::make_cute_packed_stride( + typename GemmKernel::StrideB{}, cute::make_shape(n, k, cute::Int<1>{})); + auto stride_C = cutlass::make_cute_packed_stride( + typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{})); + auto stride_D = cutlass::make_cute_packed_stride( + typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{})); + + auto arguments_base = typename Gemm::Arguments { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + A_block.get(), stride_A, + B_block.get(), stride_B + }, + { // Epilogue arguments + {}, // thread + has_c ? C_block.get() : nullptr, stride_C, + D_block.get(), stride_D, + }, // Epilogue arguments end + /*hw_info=*/{}, + /*scheduler_args=*/{} + }; + + // check without D aux store + // set D to be void and use Sm90AuxStore to write to D + // and then the D is the same + GemmWithoutD gemm_op; + + auto arguments = typename GemmWithoutD::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + A_block.get(), stride_A, + B_block.get(), stride_B + }, + { // Epilogue arguments + {}, // thread + has_c ? C_block.get() : nullptr, stride_C, + nullptr, stride_D, + }, // Epilogue arguments end + /*hw_info=*/{}, + /*scheduler_args=*/{} + }; + + constexpr float beta [[maybe_unused]] = 1.0; + constexpr float alpha [[maybe_unused]] = 1.0; + + using ElementC = typename GemmWithoutD::ElementC; + + if constexpr (not has_c) { + arguments_base.epilogue.thread = { + // binary op : alpha * acc + {{alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }; + arguments.epilogue.thread = { + // unary op: aux store D + { + // binary op : alpha * acc + {{alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }, + {aux_store_D_block.get(), stride_D} + }; + + } else { + arguments_base.epilogue.thread = { + // ternary op : beta * C + (alpha * acc) + {{beta}}, // leaf op+args : beta + {}, // op+args : C + { + // binary op : alpha * acc + {{alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; + arguments.epilogue.thread = { + // unary op: aux store D + { + // ternary op : beta * C + (alpha * acc) + {{beta}}, // leaf op+args : beta + {}, // op+args : C + { + // binary op : alpha * acc + {{alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, + {aux_store_D_block.get(), stride_D} + }; + } + + + cutlass::Status status; + cudaError_t result; + + status = gemm_op_base.can_implement(arguments_base); + EXPECT_EQ(status, cutlass::Status::kSuccess) << "Error gemm base not supported"; + size_t workspace_size_base = Gemm::get_workspace_size(arguments_base); + workspace.reset(workspace_size_base); + status = gemm_op_base.initialize(arguments_base, workspace.get()); + status = gemm_op_base.run(); + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << "Error at Base Kernel Sync."; + + size_t workspace_size = GemmWithoutD::get_workspace_size(arguments); + workspace.reset(workspace_size); + status = gemm_op.can_implement(arguments); + EXPECT_EQ(status, cutlass::Status::kSuccess); + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + + bool passed = cutlass::reference::device::BlockCompareEqual(aux_store_D_block.get(), D_block.get(), m * n); + if (!passed) { + return false; + } + } + } + } + return true; +} +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_VoidC_VoidD_AuxStoreF16_RowMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t + >; + + using namespace cutlass::epilogue::fusion; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr bool has_c = false; + + using EVT_D = decltype(test::gemm::device::select_evt_d()); + using AuxStore = Sm90AuxStore; + + constexpr auto select_kernel = [](auto has_c, auto has_d) { + using FusionCallbacks = + cute::conditional_t>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cute::conditional_t, LayoutC, 8, + cute::conditional_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + return *(GemmKernel *)(nullptr); + }; + + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); + using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::testEVTAuxStoreWithoutD(); + + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_VoidC_VoidD_AuxStoreF16_ColumnMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t + >; + + using namespace cutlass::epilogue::fusion; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr bool has_c = false; + + using EVT_D = decltype(test::gemm::device::select_evt_d()); + using AuxStore = Sm90AuxStore; + + constexpr auto select_kernel = [](auto has_c, auto has_d) { + using FusionCallbacks = + cute::conditional_t>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cute::conditional_t, LayoutC, 8, + cute::conditional_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + return *(GemmKernel *)(nullptr); + }; + + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); + using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::testEVTAuxStoreWithoutD(); + + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1_VoidC_VoidD_AuxStoreF32_RowMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t + >; + + using namespace cutlass::epilogue::fusion; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr bool has_c = false; + + using EVT_D = decltype(test::gemm::device::select_evt_d()); + using AuxStore = Sm90AuxStore; + + constexpr auto select_kernel = [](auto has_c, auto has_d) { + using FusionCallbacks = + cute::conditional_t>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cute::conditional_t, LayoutC, 8, + cute::conditional_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + return *(GemmKernel *)(nullptr); + }; + + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); + using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::testEVTAuxStoreWithoutD(); + + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_WithC_VoidD_AuxStoreF16_RowMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t + >; + + using namespace cutlass::epilogue::fusion; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr bool has_c = true; + + using EVT_D = decltype(test::gemm::device::select_evt_d()); + using AuxStore = Sm90AuxStore; + + constexpr auto select_kernel = [](auto has_c, auto has_d) { + using FusionCallbacks = + cute::conditional_t>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cute::conditional_t, LayoutC, 8, + cute::conditional_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + return *(GemmKernel *)(nullptr); + }; + + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); + using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::testEVTAuxStoreWithoutD(); + + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_WithC_VoidD_AuxStoreF16_ColumnMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t + >; + + using namespace cutlass::epilogue::fusion; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr bool has_c = true; + + using EVT_D = decltype(test::gemm::device::select_evt_d()); + using AuxStore = Sm90AuxStore; + + constexpr auto select_kernel = [](auto has_c, auto has_d) { + using FusionCallbacks = + cute::conditional_t>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cute::conditional_t, LayoutC, 8, + cute::conditional_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + return *(GemmKernel *)(nullptr); + }; + + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); + using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::testEVTAuxStoreWithoutD(); + + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1_WithC_VoidD_AuxStoreF32_RowMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t + >; + + using namespace cutlass::epilogue::fusion; + + constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr bool has_c = true; + + using EVT_D = decltype(test::gemm::device::select_evt_d()); + using AuxStore = Sm90AuxStore; + + constexpr auto select_kernel = [](auto has_c, auto has_d) { + using FusionCallbacks = + cute::conditional_t>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cute::conditional_t, LayoutC, 8, + cute::conditional_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue>; + + return *(GemmKernel *)(nullptr); + }; + + using GemmKernel = decltype(select_kernel(cute::C{}, cute::C{})); + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using GemmKernelWithoutD = decltype(select_kernel(cute::C{}, cute::C{})); + using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::testEVTAuxStoreWithoutD(); + + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)