Support ElementD to be void for tma (#1153)
* Support void D with AuxStore * refine get_element_aux
This commit is contained in:
parent
751eb9a885
commit
362abbf274
@ -254,17 +254,21 @@ template <
|
||||
class ElementC_,
|
||||
class GmemLayoutTagC_,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
class ElementD_,
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class FusionOpOrCallbacks,
|
||||
class DispatchPolicy
|
||||
>
|
||||
struct Sm90TmaBuilderImpl {
|
||||
// Passing void D disables destination store + smem allocation
|
||||
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
|
||||
fusion::get_element_aux_t<FusionOpOrCallbacks>, ElementD_>;
|
||||
|
||||
// Passing void C disables source load + smem allocation
|
||||
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
||||
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
|
||||
|
||||
|
||||
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
||||
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
||||
|
||||
@ -292,7 +296,7 @@ struct Sm90TmaBuilderImpl {
|
||||
EpilogueTile_MN,
|
||||
ElementC_, // Need to pass void through to expose via GemmUniversal
|
||||
GmemStrideTypeC,
|
||||
ElementD,
|
||||
ElementD_,
|
||||
GmemStrideTypeD,
|
||||
FusionCallbacks,
|
||||
CopyOpG2S,
|
||||
@ -474,7 +478,7 @@ template <
|
||||
class ElementC,
|
||||
class GmemLayoutTagC,
|
||||
int AlignmentC,
|
||||
class ElementD,
|
||||
class ElementD_,
|
||||
class GmemLayoutTagD,
|
||||
int AlignmentD,
|
||||
class Schedule,
|
||||
@ -491,7 +495,7 @@ struct CollectiveBuilder<
|
||||
ElementC,
|
||||
GmemLayoutTagC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
ElementD_,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
Schedule,
|
||||
@ -499,6 +503,8 @@ struct CollectiveBuilder<
|
||||
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
|
||||
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> >> {
|
||||
private:
|
||||
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
|
||||
fusion::get_element_aux_t<FusionOperation>, ElementD_>;
|
||||
using EpilogueTile_MN =
|
||||
decltype(detail::sm90_compute_tile_shape_or_override<ElementD, EpilogueTileType, Schedule, TileShape_MNK>());
|
||||
using DispatchPolicy =
|
||||
@ -514,7 +520,7 @@ public:
|
||||
ElementC,
|
||||
GmemLayoutTagC,
|
||||
AlignmentC,
|
||||
ElementD,
|
||||
ElementD_,
|
||||
GmemLayoutTagD,
|
||||
AlignmentD,
|
||||
FusionOperation,
|
||||
|
@ -121,11 +121,14 @@ public:
|
||||
static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");
|
||||
|
||||
private:
|
||||
using SmemElementC = cute::conditional_t<cute::is_void_v<ElementC>,ElementD,ElementC>; // prevents void ref breakages
|
||||
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
|
||||
constexpr static bool is_destination_supported = not cute::is_void_v<ElementD>;
|
||||
using SmemElementD = cute::conditional_t<not is_destination_supported,fusion::get_element_aux_t<FusionCallbacks>, ElementD>;
|
||||
static_assert(not cute::is_void_v<SmemElementD>, "SmemElementD is void");
|
||||
using SmemElementC = cute::conditional_t<not is_source_supported,SmemElementD,ElementC>; // prevents void ref breakages
|
||||
constexpr static int StagesC = StagesC_;
|
||||
constexpr static int StagesD = StagesD_;
|
||||
constexpr static bool ReuseSmemC = ReuseSmemC_;
|
||||
constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
|
||||
constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported;
|
||||
|
||||
constexpr static bool is_m_major_C = detail::is_m_major<StrideC>();
|
||||
constexpr static bool is_m_major_D = detail::is_m_major<StrideD>();
|
||||
@ -139,23 +142,33 @@ private:
|
||||
make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int<ReuseSmemC ? StagesC : StagesD>{}),
|
||||
cute::conditional_t<is_m_major_D, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
|
||||
|
||||
constexpr static bool support_smem_reuse = is_source_supported && StagesD <= StagesC
|
||||
constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC
|
||||
&& cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{}));
|
||||
static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met");
|
||||
|
||||
constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
|
||||
constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});
|
||||
|
||||
using EmptyType = cute::tuple<>;
|
||||
using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
|
||||
array_aligned<SmemElementC, size(SmemLayoutC{}), SmemAlignmentC>,
|
||||
EmptyType>;
|
||||
using SmemDStorage = cute::conditional_t<is_destination_supported,
|
||||
array_aligned<SmemElementD, size(SmemLayoutD{}), SmemAlignmentD>,
|
||||
EmptyType>;
|
||||
|
||||
struct TensorStorageWithC {
|
||||
alignas(SmemAlignmentC) array_aligned<SmemElementC, size(SmemLayoutC{})> smem_C;
|
||||
alignas(SmemAlignmentD) array_aligned<ElementD, size(SmemLayoutD{})> smem_D;
|
||||
struct TensorStorageImpl: cute::tuple<SmemCStorage, SmemDStorage> {
|
||||
using Base = cute::tuple<SmemCStorage, SmemDStorage>;
|
||||
|
||||
using FusionStorage = typename FusionCallbacks::SharedStorage;
|
||||
FusionStorage thread;
|
||||
};
|
||||
constexpr decltype(auto)
|
||||
smem_C() {
|
||||
return cute::get<0>(static_cast<Base &>(*this));
|
||||
}
|
||||
|
||||
struct TensorStorageWithoutC {
|
||||
alignas(SmemAlignmentD) array_aligned<ElementD, size(SmemLayoutD{})> smem_D;
|
||||
constexpr decltype(auto)
|
||||
smem_D() {
|
||||
return cute::get<1>(static_cast<Base &>(*this));
|
||||
}
|
||||
|
||||
using FusionStorage = typename FusionCallbacks::SharedStorage;
|
||||
FusionStorage thread;
|
||||
@ -175,8 +188,7 @@ public:
|
||||
using StorePipelineState = cutlass::PipelineState<ReuseSmemC ? StagesC : StagesD>;
|
||||
|
||||
struct SharedStorage {
|
||||
using TensorStorage =
|
||||
cute::conditional_t<not is_source_supported or ReuseSmemC, TensorStorageWithoutC, TensorStorageWithC>;
|
||||
using TensorStorage = TensorStorageImpl;
|
||||
TensorStorage tensors;
|
||||
|
||||
using PipelineStorage = typename LoadPipeline::SharedStorage;
|
||||
@ -203,7 +215,7 @@ public:
|
||||
SmemLayoutC{}(_,_,0)));
|
||||
using TMA_D = decltype(make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
make_tensor(make_gmem_ptr(static_cast<ElementD const*>(nullptr)),
|
||||
make_tensor(make_gmem_ptr(static_cast<SmemElementD const*>(nullptr)),
|
||||
repeat_like(StrideD{}, int32_t(0)), StrideD{}),
|
||||
SmemLayoutD{}(_,_,0)));
|
||||
|
||||
@ -233,16 +245,16 @@ public:
|
||||
;
|
||||
|
||||
typename Params::TMA_C tma_load_c;
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
if constexpr (is_source_supported) {
|
||||
Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M_C,N,L), args.dC));
|
||||
tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0));
|
||||
}
|
||||
|
||||
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
|
||||
typename Params::TMA_D tma_store_d = make_tma_copy(
|
||||
CopyOpS2G{},
|
||||
tensor_d,
|
||||
SmemLayoutD{}(_,_,0));
|
||||
|
||||
typename Params::TMA_D tma_store_d;
|
||||
if constexpr (is_destination_supported) {
|
||||
Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M_D,N,L), args.dD));
|
||||
tma_store_d = make_tma_copy(CopyOpS2G{}, tensor_d, SmemLayoutD{}(_,_,0));
|
||||
}
|
||||
|
||||
return {
|
||||
FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace),
|
||||
@ -272,8 +284,11 @@ public:
|
||||
auto problem_shape_MNKL = append<4>(problem_shape, 1);
|
||||
auto [M,N,K,L] = problem_shape_MNKL;
|
||||
|
||||
constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits<ElementD>::value;
|
||||
bool implementable = cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), StrideD{});
|
||||
bool implementable = true;
|
||||
if constexpr (is_destination_supported) {
|
||||
constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits<ElementD>::value;
|
||||
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), StrideD{});
|
||||
}
|
||||
|
||||
if constexpr (not cute::is_void_v<ElementC>) {
|
||||
constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits<ElementC>::value;
|
||||
@ -309,8 +324,12 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
static void
|
||||
prefetch_tma_descriptors(Params const& epilogue_params) {
|
||||
cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor());
|
||||
cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor());
|
||||
if constexpr (is_source_supported) {
|
||||
cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor());
|
||||
}
|
||||
if constexpr (is_destination_supported) {
|
||||
cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor());
|
||||
}
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
@ -365,9 +384,14 @@ public:
|
||||
Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N)
|
||||
|
||||
// Apply epilogue subtile, get matching smem tensor
|
||||
SmemElementC* ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D.data());
|
||||
if constexpr (not ReuseSmemC and is_source_supported) {
|
||||
ptr_sC = shared_tensors.smem_C.data();
|
||||
SmemElementC* ptr_sC = nullptr;
|
||||
|
||||
if constexpr (is_source_supported) {
|
||||
if constexpr (ReuseSmemC) {
|
||||
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
|
||||
} else {
|
||||
ptr_sC = shared_tensors.smem_C().data();
|
||||
}
|
||||
}
|
||||
Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
|
||||
@ -499,11 +523,20 @@ public:
|
||||
Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
|
||||
|
||||
// Construct the corresponding pipelined smem tensors
|
||||
SmemElementC* ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D.data());
|
||||
if constexpr (not ReuseSmemC and is_source_supported) {
|
||||
ptr_sC = shared_tensors.smem_C.data();
|
||||
SmemElementC* ptr_sC = nullptr;
|
||||
if constexpr (is_source_supported) {
|
||||
if constexpr (ReuseSmemC) {
|
||||
ptr_sC = reinterpret_cast<SmemElementC*>(shared_tensors.smem_D().data());
|
||||
} else {
|
||||
ptr_sC = shared_tensors.smem_C().data();
|
||||
}
|
||||
}
|
||||
ElementD* ptr_sD = shared_tensors.smem_D.data();
|
||||
|
||||
SmemElementD* ptr_sD = nullptr;
|
||||
if constexpr (is_destination_supported) {
|
||||
ptr_sD = shared_tensors.smem_D().data();
|
||||
}
|
||||
|
||||
Tensor sC_epi = cute::as_position_independent_swizzle_tensor(
|
||||
make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
|
||||
Tensor sD_epi = cute::as_position_independent_swizzle_tensor(
|
||||
@ -514,19 +547,19 @@ public:
|
||||
TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma);
|
||||
|
||||
// (t)hread-partition for (r)egister to (s)mem copy (tRS_)
|
||||
TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom<CopyOpR2S,ElementD>{}, tiled_copy_C_atom);
|
||||
TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom<CopyOpR2S,SmemElementD>{}, tiled_copy_C_atom);
|
||||
ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
|
||||
Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
|
||||
Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D)
|
||||
|
||||
// Allocate D registers
|
||||
Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi))));
|
||||
Tensor tRS_rD = make_tensor<ElementD>(tRS_rD_layout); // (R2S,R2S_M,R2S_N)
|
||||
Tensor tRS_rD = make_tensor<SmemElementD>(tRS_rD_layout); // (R2S,R2S_M,R2S_N)
|
||||
|
||||
// Vectorized fragment view
|
||||
constexpr int FragmentSize = DispatchPolicy::FragmentSize;
|
||||
Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
|
||||
Tensor tRS_rD_frg = recast<Array<ElementD , FragmentSize>>(tRS_rD);
|
||||
Tensor tRS_rD_frg = recast<Array<SmemElementD , FragmentSize>>(tRS_rD);
|
||||
CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");
|
||||
|
||||
// (t)hread-partition for (s)mem to (r)egister copy (tSR_)
|
||||
@ -653,7 +686,9 @@ public:
|
||||
}
|
||||
|
||||
// Copy tile from register to smem
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
if constexpr (is_destination_supported) {
|
||||
copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
|
||||
}
|
||||
|
||||
// Post visit, pre async fence callback entry point
|
||||
constexpr bool issue_smem_store = true; // No smem store predication
|
||||
@ -662,8 +697,10 @@ public:
|
||||
// Write the tile from smem to gmem with TMA
|
||||
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
||||
synchronize(); // ensure all threads have issued their async fence
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n));
|
||||
if constexpr (is_destination_supported) {
|
||||
if (issue_tma_store) {
|
||||
copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n));
|
||||
}
|
||||
}
|
||||
|
||||
// Post async fence, pre TMA commit callback entry point
|
||||
|
@ -1247,6 +1247,33 @@ struct FusionCallbacks<
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
namespace detail {
|
||||
template <class FusionOpOrCallbacks, class = cute::void_t<>>
|
||||
struct get_element_aux {
|
||||
using type = void;
|
||||
};
|
||||
|
||||
template <class FusionOpOrCallbacks>
|
||||
struct get_element_aux<FusionOpOrCallbacks, cute::void_t<typename FusionOpOrCallbacks::ElementAux>> {
|
||||
using type = typename FusionOpOrCallbacks::ElementAux;
|
||||
};
|
||||
|
||||
template <class NodeOp, class... ChildOps>
|
||||
struct get_element_aux<Sm90TreeVisitor<NodeOp, ChildOps...>, cute::void_t<>> {
|
||||
using type = typename get_element_aux<NodeOp>::type;
|
||||
};
|
||||
|
||||
template <class... Ts>
|
||||
struct get_element_aux<FusionCallbacks<Ts...>, cute::void_t<typename FusionCallbacks<Ts...>::Operation>> {
|
||||
private:
|
||||
using Operation = typename FusionCallbacks<Ts...>::Operation;
|
||||
public:
|
||||
using type = typename get_element_aux<Operation>::type;
|
||||
};
|
||||
}
|
||||
|
||||
template <class Callbacks>
|
||||
using get_element_aux_t = typename detail::get_element_aux<Callbacks>::type;
|
||||
|
||||
} // namespace cutlass::epilogue::fusion
|
||||
|
||||
|
@ -68,6 +68,7 @@ template <
|
||||
bool EnableNullptr = true // Noop on nullptr params
|
||||
>
|
||||
struct Sm90AuxStore {
|
||||
using ElementAux = Element;
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
|
||||
constexpr static bool is_m_major = epilogue::collective::detail::is_m_major<StrideMNL>();
|
||||
|
@ -327,6 +327,7 @@ cutlass_test_unit_add_executable(
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu
|
||||
sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_store.cu
|
||||
)
|
||||
cutlass_test_unit_add_executable(
|
||||
cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90
|
||||
|
@ -0,0 +1,688 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*! \file
|
||||
\brief Tests for Sm90 f16_f16_f16 with cooperative EVT epilogue
|
||||
D = alpha * acc + beta * c + aux_load
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
|
||||
#include "../../common/cutlass_unit_test.h"
|
||||
|
||||
#include "gemm_testbed_3x_evt.hpp"
|
||||
#include "sm90_evt_operations.hpp"
|
||||
|
||||
|
||||
#define CUTLASS_ARCH_MMA_SM90_SUPPORTED
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace test::gemm::device {
|
||||
template <class ElementCompute, class ElementAccumulator, bool IsCNeed>
|
||||
static constexpr auto select_evt_d() {
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
using BinaryCompute0 = Sm90EVT<Sm90Compute<
|
||||
cutlass::multiplies,
|
||||
ElementCompute,
|
||||
ElementCompute,
|
||||
RoundStyle>, // alpha * acc
|
||||
Sm90ScalarBroadcast<ElementAccumulator>, // alpha
|
||||
Sm90AccFetch // acc
|
||||
>;
|
||||
if constexpr (IsCNeed) {
|
||||
using EVT_D = Sm90EVT<Sm90Compute<cutlass::homogeneous_multiply_add, ElementCompute, ElementCompute, RoundStyle>,
|
||||
Sm90ScalarBroadcast<ElementAccumulator>, // beta
|
||||
Sm90SrcFetch<ElementCompute>, // C
|
||||
BinaryCompute0>;
|
||||
return *(EVT_D *)(nullptr);
|
||||
} else {
|
||||
return *(BinaryCompute0 *)(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
template <class Gemm, class GemmWithoutD>
|
||||
bool testEVTAuxStoreWithoutD() {
|
||||
using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape;
|
||||
|
||||
int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB);
|
||||
std::vector<int> problem_size_m = {max_alignment, 512 - 3 * max_alignment};
|
||||
std::vector<int> problem_size_n = {max_alignment, 512 - 2 * max_alignment};
|
||||
|
||||
if constexpr (std::is_same_v<typename Gemm::GemmKernel::DispatchPolicy::Schedule,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpong>) {
|
||||
problem_size_m.push_back(768);
|
||||
problem_size_n.push_back(768);
|
||||
}
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages;
|
||||
constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{});
|
||||
|
||||
std::vector<int> problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment};
|
||||
using ElementA = typename Gemm::ElementA;
|
||||
using ElementB = typename Gemm::ElementB;
|
||||
using ElementC = typename Gemm::ElementC;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
constexpr bool has_c = not cute::is_void_v<ElementC>;
|
||||
cutlass::DeviceAllocation<ElementA> A_block;
|
||||
cutlass::DeviceAllocation<ElementB> B_block;
|
||||
cutlass::DeviceAllocation<cute::conditional_t<has_c, ElementC, ElementD>> C_block;
|
||||
cutlass::DeviceAllocation<ElementD> D_block;
|
||||
cutlass::DeviceAllocation<ElementD> aux_store_D_block;
|
||||
cutlass::DeviceAllocation<uint8_t> workspace;
|
||||
|
||||
for (int m : problem_size_m) {
|
||||
for (int n : problem_size_n) {
|
||||
for (int k : problem_size_k) {
|
||||
ProblemShapeType problem_size;
|
||||
int l = 1;
|
||||
problem_size = ProblemShapeType{m, n, k, l};
|
||||
|
||||
// Run Base Gemm to get reference D
|
||||
A_block.reset(m * k);
|
||||
B_block.reset(k * n);
|
||||
C_block.reset(m * n);
|
||||
D_block.reset(m * n);
|
||||
aux_store_D_block.reset(m * n);
|
||||
Gemm gemm_op_base;
|
||||
|
||||
auto stride_A = cutlass::make_cute_packed_stride(
|
||||
typename GemmKernel::StrideA{}, cute::make_shape(m, k, cute::Int<1>{}));
|
||||
auto stride_B = cutlass::make_cute_packed_stride(
|
||||
typename GemmKernel::StrideB{}, cute::make_shape(n, k, cute::Int<1>{}));
|
||||
auto stride_C = cutlass::make_cute_packed_stride(
|
||||
typename GemmKernel::StrideC{}, cute::make_shape(m, n, cute::Int<1>{}));
|
||||
auto stride_D = cutlass::make_cute_packed_stride(
|
||||
typename GemmKernel::StrideD{}, cute::make_shape(m, n, cute::Int<1>{}));
|
||||
|
||||
auto arguments_base = typename Gemm::Arguments {
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
{
|
||||
A_block.get(), stride_A,
|
||||
B_block.get(), stride_B
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{}, // thread
|
||||
has_c ? C_block.get() : nullptr, stride_C,
|
||||
D_block.get(), stride_D,
|
||||
}, // Epilogue arguments end
|
||||
/*hw_info=*/{},
|
||||
/*scheduler_args=*/{}
|
||||
};
|
||||
|
||||
// check without D aux store
|
||||
// set D to be void and use Sm90AuxStore to write to D
|
||||
// and then the D is the same
|
||||
GemmWithoutD gemm_op;
|
||||
|
||||
auto arguments = typename GemmWithoutD::Arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
problem_size,
|
||||
{
|
||||
A_block.get(), stride_A,
|
||||
B_block.get(), stride_B
|
||||
},
|
||||
{ // Epilogue arguments
|
||||
{}, // thread
|
||||
has_c ? C_block.get() : nullptr, stride_C,
|
||||
nullptr, stride_D,
|
||||
}, // Epilogue arguments end
|
||||
/*hw_info=*/{},
|
||||
/*scheduler_args=*/{}
|
||||
};
|
||||
|
||||
constexpr float beta [[maybe_unused]] = 1.0;
|
||||
constexpr float alpha [[maybe_unused]] = 1.0;
|
||||
|
||||
using ElementC = typename GemmWithoutD::ElementC;
|
||||
|
||||
if constexpr (not has_c) {
|
||||
arguments_base.epilogue.thread = {
|
||||
// binary op : alpha * acc
|
||||
{{alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
};
|
||||
arguments.epilogue.thread = {
|
||||
// unary op: aux store D
|
||||
{
|
||||
// binary op : alpha * acc
|
||||
{{alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
},
|
||||
{aux_store_D_block.get(), stride_D}
|
||||
};
|
||||
|
||||
} else {
|
||||
arguments_base.epilogue.thread = {
|
||||
// ternary op : beta * C + (alpha * acc)
|
||||
{{beta}}, // leaf op+args : beta
|
||||
{}, // op+args : C
|
||||
{
|
||||
// binary op : alpha * acc
|
||||
{{alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
}, // end binary op
|
||||
{} // ternary args : multiply_add
|
||||
};
|
||||
arguments.epilogue.thread = {
|
||||
// unary op: aux store D
|
||||
{
|
||||
// ternary op : beta * C + (alpha * acc)
|
||||
{{beta}}, // leaf op+args : beta
|
||||
{}, // op+args : C
|
||||
{
|
||||
// binary op : alpha * acc
|
||||
{{alpha}}, // leaf op+args : alpha
|
||||
{}, // leaf op+args : acc
|
||||
{} // binary args : multiplies
|
||||
}, // end binary op
|
||||
{} // ternary args : multiply_add
|
||||
},
|
||||
{aux_store_D_block.get(), stride_D}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
cutlass::Status status;
|
||||
cudaError_t result;
|
||||
|
||||
status = gemm_op_base.can_implement(arguments_base);
|
||||
EXPECT_EQ(status, cutlass::Status::kSuccess) << "Error gemm base not supported";
|
||||
size_t workspace_size_base = Gemm::get_workspace_size(arguments_base);
|
||||
workspace.reset(workspace_size_base);
|
||||
status = gemm_op_base.initialize(arguments_base, workspace.get());
|
||||
status = gemm_op_base.run();
|
||||
result = cudaDeviceSynchronize();
|
||||
EXPECT_EQ(result, cudaSuccess) << "Error at Base Kernel Sync.";
|
||||
|
||||
size_t workspace_size = GemmWithoutD::get_workspace_size(arguments);
|
||||
workspace.reset(workspace_size);
|
||||
status = gemm_op.can_implement(arguments);
|
||||
EXPECT_EQ(status, cutlass::Status::kSuccess);
|
||||
status = gemm_op.initialize(arguments, workspace.get());
|
||||
status = gemm_op.run();
|
||||
result = cudaDeviceSynchronize();
|
||||
EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync.";
|
||||
|
||||
bool passed = cutlass::reference::device::BlockCompareEqual(aux_store_D_block.get(), D_block.get(), m * n);
|
||||
if (!passed) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_VoidC_VoidD_AuxStoreF16_RowMajor) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule
|
||||
>;
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr bool has_c = false;
|
||||
|
||||
using EVT_D = decltype(test::gemm::device::select_evt_d<cutlass::half_t, float, has_c>());
|
||||
using AuxStore = Sm90AuxStore<AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle,
|
||||
typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
|
||||
typename AuxStoreDescriptor::CopyOpR2S>;
|
||||
|
||||
constexpr auto select_kernel = [](auto has_c, auto has_d) {
|
||||
using FusionCallbacks =
|
||||
cute::conditional_t<decltype(has_d){}, EVT_D, Sm90EVT<AuxStore, EVT_D>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
float, float,
|
||||
cute::conditional_t<decltype(has_c){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
cute::conditional_t<decltype(has_d){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionCallbacks
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
bool passed = test::gemm::device::testEVTAuxStoreWithoutD<Gemm, GemmWithoutD>();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_VoidC_VoidD_AuxStoreF16_ColumnMajor) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule
|
||||
>;
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr bool has_c = false;
|
||||
|
||||
using EVT_D = decltype(test::gemm::device::select_evt_d<cutlass::half_t, float, has_c>());
|
||||
using AuxStore = Sm90AuxStore<AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle,
|
||||
typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
|
||||
typename AuxStoreDescriptor::CopyOpR2S>;
|
||||
|
||||
constexpr auto select_kernel = [](auto has_c, auto has_d) {
|
||||
using FusionCallbacks =
|
||||
cute::conditional_t<decltype(has_d){}, EVT_D, Sm90EVT<AuxStore, EVT_D>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
float, float,
|
||||
cute::conditional_t<decltype(has_c){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
cute::conditional_t<decltype(has_d){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionCallbacks
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
bool passed = test::gemm::device::testEVTAuxStoreWithoutD<Gemm, GemmWithoutD>();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1_VoidC_VoidD_AuxStoreF32_RowMajor) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule
|
||||
>;
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr bool has_c = false;
|
||||
|
||||
using EVT_D = decltype(test::gemm::device::select_evt_d<cutlass::half_t, float, has_c>());
|
||||
using AuxStore = Sm90AuxStore<AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle,
|
||||
typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
|
||||
typename AuxStoreDescriptor::CopyOpR2S>;
|
||||
|
||||
constexpr auto select_kernel = [](auto has_c, auto has_d) {
|
||||
using FusionCallbacks =
|
||||
cute::conditional_t<decltype(has_d){}, EVT_D, Sm90EVT<AuxStore, EVT_D>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
float, float,
|
||||
cute::conditional_t<decltype(has_c){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
cute::conditional_t<decltype(has_d){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionCallbacks
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
bool passed = test::gemm::device::testEVTAuxStoreWithoutD<Gemm, GemmWithoutD>();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_WithC_VoidD_AuxStoreF16_RowMajor) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule
|
||||
>;
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr bool has_c = true;
|
||||
|
||||
using EVT_D = decltype(test::gemm::device::select_evt_d<cutlass::half_t, float, has_c>());
|
||||
using AuxStore = Sm90AuxStore<AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle,
|
||||
typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
|
||||
typename AuxStoreDescriptor::CopyOpR2S>;
|
||||
|
||||
constexpr auto select_kernel = [](auto has_c, auto has_d) {
|
||||
using FusionCallbacks =
|
||||
cute::conditional_t<decltype(has_d){}, EVT_D, Sm90EVT<AuxStore, EVT_D>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
float, float,
|
||||
cute::conditional_t<decltype(has_c){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
cute::conditional_t<decltype(has_d){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionCallbacks
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
bool passed = test::gemm::device::testEVTAuxStoreWithoutD<Gemm, GemmWithoutD>();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_WithC_VoidD_AuxStoreF16_ColumnMajor) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::ColumnMajor;
|
||||
using TileShape_MNK = Shape<_256,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule
|
||||
>;
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr bool has_c = true;
|
||||
|
||||
using EVT_D = decltype(test::gemm::device::select_evt_d<cutlass::half_t, float, has_c>());
|
||||
using AuxStore = Sm90AuxStore<AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle,
|
||||
typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
|
||||
typename AuxStoreDescriptor::CopyOpR2S>;
|
||||
|
||||
constexpr auto select_kernel = [](auto has_c, auto has_d) {
|
||||
using FusionCallbacks =
|
||||
cute::conditional_t<decltype(has_d){}, EVT_D, Sm90EVT<AuxStore, EVT_D>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
float, float,
|
||||
cute::conditional_t<decltype(has_c){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
cute::conditional_t<decltype(has_d){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionCallbacks
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
bool passed = test::gemm::device::testEVTAuxStoreWithoutD<Gemm, GemmWithoutD>();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
|
||||
TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1_WithC_VoidD_AuxStoreF32_RowMajor) {
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using TileShape_MNK = Shape<_128,_128,_64>;
|
||||
using ClusterShape_MNK = Shape<_2,_2,_1>;
|
||||
|
||||
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
|
||||
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
|
||||
|
||||
using EpilogueDescriptor = cutlass::epilogue::collective::detail::EpilogueDescriptor<
|
||||
TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule
|
||||
>;
|
||||
using AuxStoreDescriptor = cutlass::epilogue::collective::detail::AuxStoreDescriptor<
|
||||
EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t
|
||||
>;
|
||||
|
||||
using namespace cutlass::epilogue::fusion;
|
||||
|
||||
constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
constexpr bool has_c = true;
|
||||
|
||||
using EVT_D = decltype(test::gemm::device::select_evt_d<cutlass::half_t, float, has_c>());
|
||||
using AuxStore = Sm90AuxStore<AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile,
|
||||
typename AuxStoreDescriptor::Element, RoundStyle,
|
||||
typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom,
|
||||
typename AuxStoreDescriptor::CopyOpR2S>;
|
||||
|
||||
constexpr auto select_kernel = [](auto has_c, auto has_d) {
|
||||
using FusionCallbacks =
|
||||
cute::conditional_t<decltype(has_d){}, EVT_D, Sm90EVT<AuxStore, EVT_D>>;
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
EpilogueTileType,
|
||||
float, float,
|
||||
cute::conditional_t<decltype(has_c){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
cute::conditional_t<decltype(has_d){}, cutlass::half_t, void>, LayoutC, 8,
|
||||
EpilogueSchedule,
|
||||
FusionCallbacks
|
||||
>::CollectiveOp;
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
||||
cutlass::half_t, LayoutA, 8,
|
||||
cutlass::half_t, LayoutB, 8,
|
||||
float,
|
||||
TileShape_MNK, ClusterShape_MNK,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<sizeof(typename CollectiveEpilogue::SharedStorage)>,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperative
|
||||
>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int,int,int,int>,
|
||||
CollectiveMainloop,
|
||||
CollectiveEpilogue>;
|
||||
|
||||
return *(GemmKernel *)(nullptr);
|
||||
};
|
||||
|
||||
using GemmKernel = decltype(select_kernel(cute::C<has_c>{}, cute::C<true>{}));
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using GemmKernelWithoutD = decltype(select_kernel(cute::C<has_c>{}, cute::C<false>{}));
|
||||
using GemmWithoutD = cutlass::gemm::device::GemmUniversalAdapter<GemmKernelWithoutD>;
|
||||
|
||||
bool passed = test::gemm::device::testEVTAuxStoreWithoutD<Gemm, GemmWithoutD>();
|
||||
|
||||
EXPECT_TRUE(passed);
|
||||
}
|
||||
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
Loading…
Reference in New Issue
Block a user