810 lines
28 KiB
C++
810 lines
28 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include "cute/atom/mma_traits_sm90.hpp"
|
|
#include "cute/atom/mma_traits_sm90_gmma.hpp"
|
|
#include "cute/atom/copy_traits_sm90.hpp"
|
|
|
|
#include "cutlass/detail/dependent_false.hpp"
|
|
#include "cutlass/detail/layout.hpp"
|
|
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
|
#include "cutlass/epilogue/dispatch_policy.hpp"
|
|
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
|
|
#include "cutlass/epilogue/thread/linear_combination.h"
|
|
#include "cutlass/epilogue/thread/linear_combination_generic.h"
|
|
#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h"
|
|
#include "cutlass/epilogue/fusion/callbacks.hpp"
|
|
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
|
|
|
|
#if defined(__CUDACC_RTC__)
|
|
#include <cuda/std/type_traits>
|
|
#else
|
|
#include <type_traits>
|
|
#endif
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass::epilogue::collective {
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace detail {
|
|
|
|
// Returns the parameterized dispatch policy for the TMA epilogue
|
|
template<class TileShapeMNK, class EpilogueTileMN, class ElementC, class ElementD, class Schedule>
|
|
constexpr auto
|
|
sm90_get_tma_dispatch_policy() {
|
|
using namespace cute;
|
|
|
|
constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{}));
|
|
constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v<Schedule> ? 256 : 128);
|
|
constexpr int ReuseSmemC = (sizeof_bits_v<ElementC> == sizeof_bits_v<ElementD>) && (sizeof_bits_v<ElementD> > 8);
|
|
constexpr int StagesD = 2;
|
|
constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles;
|
|
|
|
return Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>{};
|
|
}
|
|
|
|
// Returns the smem layout atom to be used for C or D matrix
|
|
template<class GmemStrideType, class Element, class EpilogueTile_MN>
|
|
constexpr auto
|
|
sm90_get_epilogue_smem_swizzle_layout_atom() {
|
|
using namespace cute;
|
|
|
|
// ColMajor C/D (M-major)
|
|
if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) {
|
|
return cutlass::gemm::collective::detail::ss_smem_selector<
|
|
cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{}))
|
|
>();
|
|
}
|
|
// RowMajor C/D (N-major)
|
|
else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) {
|
|
return cutlass::gemm::collective::detail::ss_smem_selector<
|
|
cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{}))
|
|
>();
|
|
}
|
|
else {
|
|
static_assert(cutlass::detail::dependent_false<GmemStrideType>, "Unsupported gmem layout.");
|
|
}
|
|
}
|
|
|
|
// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one.
|
|
template <class ElementD, class EpilogueTileType, class Schedule, class TileShape_MNK>
|
|
constexpr auto
|
|
sm90_compute_tile_shape_or_override() {
|
|
if constexpr (cute::is_same_v<EpilogueTileType, EpilogueTileAuto>) {
|
|
|
|
if constexpr (detail::sm90_is_cooperative_v<Schedule>) {
|
|
using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{})));
|
|
if constexpr (size<0>(TileShape_MNK{}) >= 128) {
|
|
return Shape<_128, N_tile>{};
|
|
}
|
|
else {
|
|
return Shape<_64, N_tile>{};
|
|
}
|
|
}
|
|
else if constexpr (detail::sm90_is_warp_specialized_v<Schedule>) {
|
|
if constexpr (sizeof_bits_v<ElementD> == 8) {
|
|
using N_tile = decltype(cute::min(_64{}, get<1>(TileShape_MNK{})));
|
|
return Shape<_64, N_tile>{};
|
|
}
|
|
else {
|
|
using N_tile = decltype(cute::min(_32{}, get<1>(TileShape_MNK{})));
|
|
return Shape<_64,N_tile>{};
|
|
}
|
|
}
|
|
else {
|
|
static_assert(cutlass::detail::dependent_false<Schedule>, "Unsupported schedule.");
|
|
}
|
|
}
|
|
else if constexpr (cute::is_tuple<EpilogueTileType>::value) {
|
|
EpilogueTileType epi_tile;
|
|
constexpr int M = size<0>(shape(epi_tile));
|
|
constexpr int N = size<1>(shape(epi_tile));
|
|
|
|
static_assert(!is_layout<EpilogueTileType>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
|
|
static_assert(M == 64 && detail::sm90_is_warp_specialized_v<Schedule> ||
|
|
M == 128 && detail::sm90_is_cooperative_v<Schedule>, "Unsupported tile shape");
|
|
static_assert(N % 16 == 0, "Unsupported tile shape");
|
|
|
|
return epi_tile;
|
|
}
|
|
else {
|
|
static_assert(cutlass::detail::dependent_false<EpilogueTileType>, "Invalid type for EpilogueTileType.");
|
|
}
|
|
}
|
|
|
|
// Selects the largest vectorized smem store atom available
|
|
template <class GmemStrideTypeD, class ElementD>
|
|
constexpr auto
|
|
sm90_get_smem_store_op_for_accumulator() {
|
|
using namespace cute;
|
|
|
|
if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) {
|
|
return SM90_U16x8_STSM_T{};
|
|
}
|
|
else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) {
|
|
return SM90_U32x4_STSM_N{};
|
|
}
|
|
else {
|
|
// auto-vectorizing store
|
|
return AutoVectorizingCopyWithAssumedAlignment{};
|
|
}
|
|
}
|
|
|
|
// Selects the largest vectorized smem load atom available
|
|
template <class GmemStrideTypeC, class ElementC>
|
|
constexpr auto
|
|
sm90_get_smem_load_op_for_source() {
|
|
using namespace cute;
|
|
|
|
// Reuse the logic from smem store selector
|
|
using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator<GmemStrideTypeC, ElementC>());
|
|
|
|
if constexpr (cute::is_same_v<SmemStoreOp, SM90_U16x8_STSM_T>) {
|
|
return SM75_U16x8_LDSM_T{};
|
|
}
|
|
else if constexpr (cute::is_same_v<SmemStoreOp, SM90_U32x4_STSM_N>) {
|
|
return SM75_U32x4_LDSM_N{};
|
|
}
|
|
else {
|
|
// auto-vectorizing load
|
|
return AutoVectorizingCopyWithAssumedAlignment{};
|
|
}
|
|
}
|
|
|
|
// callbacks builder with TMA aux out
|
|
template <
|
|
int StagesC,
|
|
int StagesD,
|
|
int FragmentSize,
|
|
bool ReuseSmemC,
|
|
class FusionOp,
|
|
class TileShape_MNK,
|
|
class EpilogueTile_MN,
|
|
class ElementAccumulator
|
|
>
|
|
struct CallbacksBuilder<
|
|
Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
|
FusionOp,
|
|
TileShape_MNK,
|
|
EpilogueTile_MN,
|
|
ElementAccumulator,
|
|
enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor
|
|
&& not is_subbyte_v<typename FusionOp::ElementAux>>
|
|
> {
|
|
using GmemStrideTypeAux = gemm::TagToStrideC_t<typename FusionOp::GmemLayoutTagAux>;
|
|
using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<
|
|
GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>());
|
|
using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<
|
|
GmemStrideTypeAux, typename FusionOp::ElementAux>());
|
|
using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source<
|
|
GmemStrideTypeAux, typename FusionOp::ElementAux>());
|
|
using SmemCopyOpAux = conditional_t<FusionOp::IsAuxOutSupported, CopyOpR2S, CopyOpS2R>;
|
|
|
|
using Callbacks = fusion::FusionCallbacks<
|
|
Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
|
FusionOp, TileShape_MNK, EpilogueTile_MN,
|
|
SmemLayoutAtomAux, SmemCopyOpAux
|
|
>;
|
|
};
|
|
|
|
template <
|
|
int StagesC,
|
|
int StagesD,
|
|
int FragmentSize,
|
|
bool ReuseSmemC,
|
|
class FusionOp,
|
|
class TileShape_MNK,
|
|
class EpilogueTile_MN,
|
|
class ElementAccumulator
|
|
>
|
|
struct CallbacksBuilder<
|
|
Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
|
FusionOp,
|
|
TileShape_MNK,
|
|
EpilogueTile_MN,
|
|
ElementAccumulator,
|
|
enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor
|
|
&& sizeof_bits_v<typename FusionOp::ElementAux> == 1>
|
|
> {
|
|
using Callbacks = fusion::FusionCallbacks<
|
|
Sm90TmaWarpSpecialized<StagesC, StagesD, FragmentSize, ReuseSmemC>,
|
|
FusionOp, TileShape_MNK, EpilogueTile_MN,
|
|
Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem
|
|
>;
|
|
};
|
|
|
|
// Helper for building TMA warp-specialized collective epilogues, specialized by
|
|
// the fusion operation performed and the dispatch policy to use.
|
|
template <
|
|
class TileShape_MNK,
|
|
class EpilogueTile_MN,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC_,
|
|
class GmemLayoutTagC_,
|
|
int AlignmentC,
|
|
class ElementD_,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
class FusionOpOrCallbacks,
|
|
class DispatchPolicy
|
|
>
|
|
struct Sm90TmaBuilderImpl {
|
|
// Passing void D disables destination store + smem allocation
|
|
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
|
|
fusion::get_element_aux_t<FusionOpOrCallbacks>, ElementD_>;
|
|
|
|
// Passing void C disables source load + smem allocation
|
|
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
|
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
|
|
|
|
using GmemStrideTypeC = cutlass::detail::TagToStrideC_t<GmemLayoutTagC>;
|
|
using GmemStrideTypeD = cutlass::detail::TagToStrideC_t<GmemLayoutTagD>;
|
|
|
|
using CopyOpS2G =
|
|
SM90_TMA_STORE
|
|
;
|
|
using CopyOpG2S =
|
|
SM90_TMA_LOAD
|
|
;
|
|
|
|
// TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks
|
|
// instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination
|
|
using FusionCallbacks =
|
|
typename CallbacksBuilder<
|
|
DispatchPolicy,
|
|
FusionOpOrCallbacks,
|
|
TileShape_MNK,
|
|
EpilogueTile_MN,
|
|
ElementAccumulator
|
|
>::Callbacks;
|
|
|
|
using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue<
|
|
DispatchPolicy,
|
|
TileShape_MNK,
|
|
EpilogueTile_MN,
|
|
ElementC_, // Need to pass void through to expose via GemmUniversal
|
|
GmemStrideTypeC,
|
|
ElementD_,
|
|
GmemStrideTypeD,
|
|
FusionCallbacks,
|
|
CopyOpG2S,
|
|
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, EpilogueTile_MN>()),
|
|
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
|
|
CopyOpS2G,
|
|
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
|
|
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>())
|
|
>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
// Descriptor classes for defining EVT nodes
|
|
// Some of the epilogue visitor nodes require non-intuitive template arguments
|
|
// such as CopyOpS2R for AuxLoad node. Traditionaly, these are resolved by the
|
|
// builder classes. Here we provide a set of descriptor classes that resolve
|
|
// these template arguments from more intuitive types such as Stride, Layout
|
|
|
|
// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD
|
|
template<
|
|
typename TileShape_MNK,
|
|
typename EpilogueTileType,
|
|
typename ElementC,
|
|
typename ElementD,
|
|
typename Schedule
|
|
>
|
|
struct EpilogueDescriptor {
|
|
using TileShape = TileShape_MNK;
|
|
using EpilogueTile =
|
|
decltype(
|
|
detail::sm90_compute_tile_shape_or_override<
|
|
ElementD, EpilogueTileType, Schedule, TileShape_MNK
|
|
>()
|
|
);
|
|
using DispatchPolicy =
|
|
decltype(
|
|
detail::sm90_get_tma_dispatch_policy<
|
|
TileShape_MNK, EpilogueTile,
|
|
ElementC, ElementD, Schedule
|
|
>()
|
|
);
|
|
constexpr static int StagesC = DispatchPolicy::StagesC;
|
|
constexpr static int StagesD = DispatchPolicy::StagesD;
|
|
};
|
|
|
|
// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node
|
|
template<
|
|
typename EpilogueDescriptor,
|
|
typename StrideOrLayoutTag,
|
|
typename ElementAux
|
|
>
|
|
struct AuxLoadDescriptor {
|
|
constexpr static int Stages = EpilogueDescriptor::StagesC;
|
|
using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
|
|
using Element = ElementAux;
|
|
using Stride = cutlass::detail::TagToStrideC_t<StrideOrLayoutTag>;
|
|
using SmemLayoutAtom =
|
|
decltype(
|
|
detail::sm90_get_epilogue_smem_swizzle_layout_atom<
|
|
Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile
|
|
>()
|
|
);
|
|
using CopyOpS2R =
|
|
decltype(detail::sm90_get_smem_load_op_for_source<Stride, ElementAux>());
|
|
};
|
|
|
|
// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node
|
|
template<
|
|
typename EpilogueDescriptor,
|
|
typename StrideOrLayoutTag,
|
|
typename ElementAux
|
|
>
|
|
struct AuxStoreDescriptor {
|
|
constexpr static int Stages = EpilogueDescriptor::StagesD;
|
|
using EpilogueTile = typename EpilogueDescriptor::EpilogueTile;
|
|
using Element = ElementAux;
|
|
using Stride = cutlass::detail::TagToStrideC_t<StrideOrLayoutTag>;
|
|
using SmemLayoutAtom =
|
|
decltype(
|
|
detail::sm90_get_epilogue_smem_swizzle_layout_atom<
|
|
Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile
|
|
>()
|
|
);
|
|
using CopyOpR2S =
|
|
decltype(detail::sm90_get_smem_store_op_for_accumulator<Stride, ElementAux>());
|
|
};
|
|
|
|
template<
|
|
typename EpilogueDescriptor,
|
|
typename ElementVector
|
|
>
|
|
struct RowBroadcastDescriptor {
|
|
constexpr static int Stages = ceil_div(
|
|
EpilogueDescriptor::StagesC,
|
|
size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{}))
|
|
) + 1;
|
|
|
|
using Element = ElementVector;
|
|
};
|
|
|
|
} // namespace detail
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// No-smem builder
|
|
template <
|
|
class TileShape_MNK,
|
|
class ClusterShape_MNK,
|
|
class EpilogueTileType,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC_,
|
|
class GmemLayoutTagC_,
|
|
int AlignmentC,
|
|
class ElementD,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
class Schedule,
|
|
FloatRoundStyle RoundStyle
|
|
>
|
|
struct CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC_,
|
|
GmemLayoutTagC_,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
Schedule,
|
|
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
|
|
cute::enable_if_t<cute::is_same_v<Schedule, NoSmemWarpSpecialized> ||
|
|
cute::is_same_v<Schedule, PtrArrayNoSmemWarpSpecialized> >> {
|
|
|
|
// Passing void C disables source load
|
|
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
|
ElementD, ElementC_>; // prevents cute breakages
|
|
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
|
GmemLayoutTagD, GmemLayoutTagC_>;
|
|
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
|
|
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
|
|
|
static constexpr int FragmentSize = 1;
|
|
using ThreadOp = thread::LinearCombination<
|
|
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
|
|
ScaleType, RoundStyle, ElementC>;
|
|
|
|
using CollectiveOp = cute::conditional_t<
|
|
cute::is_same_v<Schedule, NoSmemWarpSpecialized>,
|
|
cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
|
cutlass::epilogue::collective::DefaultEpilogue<
|
|
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
|
|
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
|
|
ThreadOp,
|
|
cutlass::gemm::EpilogueDefault>>,
|
|
// Epilogue for Ptr-Array and Grouped Gemm
|
|
cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
|
cutlass::epilogue::collective::DefaultEpilogueArray<
|
|
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
|
|
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
|
|
ThreadOp,
|
|
Schedule>>
|
|
>;
|
|
};
|
|
|
|
// Tma warp-specialized builder
|
|
template <
|
|
class TileShape_MNK,
|
|
class ClusterShape_MNK,
|
|
class EpilogueTileType,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC,
|
|
class GmemLayoutTagC,
|
|
int AlignmentC,
|
|
class ElementD_,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
class Schedule,
|
|
class FusionOperation
|
|
>
|
|
struct CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC,
|
|
GmemLayoutTagC,
|
|
AlignmentC,
|
|
ElementD_,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
Schedule,
|
|
FusionOperation,
|
|
cute::enable_if_t<cute::is_same_v<Schedule, TmaWarpSpecialized> ||
|
|
cute::is_same_v<Schedule, TmaWarpSpecializedCooperative> >> {
|
|
private:
|
|
using ElementD = cute::conditional_t<cute::is_void_v<ElementD_>,
|
|
fusion::get_element_aux_t<FusionOperation>, ElementD_>;
|
|
using EpilogueTile_MN =
|
|
decltype(detail::sm90_compute_tile_shape_or_override<ElementD, EpilogueTileType, Schedule, TileShape_MNK>());
|
|
using DispatchPolicy =
|
|
decltype(detail::sm90_get_tma_dispatch_policy<TileShape_MNK,EpilogueTile_MN,ElementC,ElementD,Schedule>());
|
|
|
|
public:
|
|
using CollectiveOp =
|
|
typename detail::Sm90TmaBuilderImpl<
|
|
TileShape_MNK,
|
|
EpilogueTile_MN,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC,
|
|
GmemLayoutTagC,
|
|
AlignmentC,
|
|
ElementD_,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
FusionOperation,
|
|
DispatchPolicy
|
|
>::CollectiveOp;
|
|
};
|
|
|
|
// Auto builder
|
|
template <
|
|
class TileShape_MNK,
|
|
class ClusterShape_MNK,
|
|
class EpilogueTileType,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC,
|
|
class GmemLayoutTagC,
|
|
int AlignmentC,
|
|
class ElementD,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
class FusionOperation
|
|
>
|
|
struct CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC,
|
|
GmemLayoutTagC,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
EpilogueScheduleAuto,
|
|
FusionOperation,
|
|
void> {
|
|
private:
|
|
static_assert(cute::is_same_v<FusionOperation, fusion::LinearCombination<ElementD,ElementCompute,ElementC,ElementCompute>>,
|
|
"Auto schedule doesn't support fusion. Use one of the TmaWarpSpecialized schedules instead.");
|
|
|
|
// Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance)
|
|
// since TMA epilogues are not compatible with non-TMA non-WS mainloops
|
|
using EpilogueSchedule = NoSmemWarpSpecialized;
|
|
using _CollectiveBuilder = CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC,
|
|
GmemLayoutTagC,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
EpilogueSchedule,
|
|
FusionOperation
|
|
>;
|
|
|
|
public:
|
|
using CollectiveOp = typename _CollectiveBuilder::CollectiveOp;
|
|
};
|
|
|
|
// DEPRECATED Tma warp-specialized builder for elementwise fusion
|
|
template <
|
|
class TileShape_MNK,
|
|
class ClusterShape_MNK,
|
|
class EpilogueTileType,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC,
|
|
class GmemLayoutTagC,
|
|
int AlignmentC,
|
|
class ElementD,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
class Schedule,
|
|
class UnusedFusionOp
|
|
>
|
|
struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]]
|
|
CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC,
|
|
GmemLayoutTagC,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
Schedule,
|
|
UnusedFusionOp,
|
|
cute::enable_if_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule> ||
|
|
cute::is_base_of_v<TmaWarpSpecializedCooperativeElementwiseBase, Schedule> >> {
|
|
private:
|
|
using FusionOp =
|
|
fusion::LinCombEltAct<Schedule::template ActivationFunctor, ElementD, ElementCompute, ElementC, ElementCompute, Schedule::Round>;
|
|
using ImplSchedule =
|
|
cute::conditional_t<cute::is_base_of_v<TmaWarpSpecializedElementwiseBase, Schedule>,
|
|
TmaWarpSpecialized, TmaWarpSpecializedCooperative>;
|
|
|
|
public:
|
|
using CollectiveOp =
|
|
typename CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC,
|
|
GmemLayoutTagC,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
ImplSchedule,
|
|
FusionOp
|
|
>::CollectiveOp;
|
|
};
|
|
|
|
// DEPRECATED Tma warp-specialized builder for bias + elementwise fusion
|
|
template <
|
|
class TileShape_MNK,
|
|
class ClusterShape_MNK,
|
|
class EpilogueTileType,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC_,
|
|
class GmemLayoutTagC_,
|
|
int AlignmentC,
|
|
class ElementD,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
class Schedule,
|
|
class UnusedFusionOp
|
|
>
|
|
struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]]
|
|
CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC_,
|
|
GmemLayoutTagC_,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
Schedule,
|
|
UnusedFusionOp,
|
|
cute::enable_if_t<cute::is_base_of_v<TmaWarpSpecializedBiasElementwiseBase, Schedule> ||
|
|
cute::is_base_of_v<TmaWarpSpecializedCooperativeBiasElementwiseBase, Schedule> >> {
|
|
private:
|
|
using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override<
|
|
ElementD, EpilogueTileType, Schedule, TileShape_MNK>());
|
|
// MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's
|
|
// defined as decltype of a detail::sm90_get_tma_dispatch_policy call.
|
|
// Instead, we paste in the contents of that function. A natural refactoring
|
|
// would be to create a type alias in the detail namespace.
|
|
using DispatchPolicy = Sm90TmaWarpSpecialized<
|
|
/* StagesC = */ size(shape_div(take<0, 2>(TileShape_MNK{}), EpilogueTile_MN{})),
|
|
/* StagesD = */ 2,
|
|
/* FragmentSize = */ size(EpilogueTile_MN{}) / (detail::sm90_is_cooperative_v<Schedule> ? 256 : 128),
|
|
/* ReuseSmemC = */ sizeof_bits_v<ElementC_> == sizeof_bits_v<ElementD>
|
|
>;
|
|
|
|
using GmemStrideTypeAux = gemm::TagToStrideC_t<GmemLayoutTagD>;
|
|
using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<
|
|
GmemStrideTypeAux, typename Schedule::ElementT, EpilogueTile_MN>());
|
|
using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator<
|
|
GmemStrideTypeAux, typename Schedule::ElementT>());
|
|
using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux<
|
|
GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute,
|
|
typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute
|
|
>;
|
|
using FusionCallbacksAux = fusion::FusionCallbacks<
|
|
DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux
|
|
>;
|
|
|
|
using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct<
|
|
Schedule::template ActivationFunctor, ElementD, ElementCompute,
|
|
typename Schedule::ElementBias, ElementC_, ElementCompute
|
|
>;
|
|
using FusionCallbacksNoAux = fusion::FusionCallbacks<
|
|
DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN
|
|
>;
|
|
|
|
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,ElementD,ElementC_>; // prevents void ref breakages
|
|
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,GmemLayoutTagD,GmemLayoutTagC_>;
|
|
|
|
using GmemStrideTypeC = gemm::TagToStrideC_t<GmemLayoutTagC>;
|
|
using GmemStrideTypeD = gemm::TagToStrideC_t<GmemLayoutTagD>;
|
|
|
|
public:
|
|
using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise<
|
|
DispatchPolicy::StagesC,
|
|
DispatchPolicy::StagesD,
|
|
DispatchPolicy::FragmentSize,
|
|
TileShape_MNK,
|
|
EpilogueTile_MN,
|
|
ElementC_, // Need to pass void through to expose via GemmUniversal
|
|
GmemStrideTypeC,
|
|
ElementD,
|
|
GmemStrideTypeD,
|
|
cute::conditional_t<Schedule::StoreT, FusionCallbacksAux, FusionCallbacksNoAux>,
|
|
SM90_TMA_LOAD,
|
|
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeC, ElementC, EpilogueTile_MN>()),
|
|
decltype(detail::sm90_get_smem_load_op_for_source<GmemStrideTypeC, ElementC>()),
|
|
SM90_TMA_STORE,
|
|
decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<GmemStrideTypeD, ElementD, EpilogueTile_MN>()),
|
|
decltype(detail::sm90_get_smem_store_op_for_accumulator<GmemStrideTypeD, ElementD>())
|
|
>;
|
|
};
|
|
|
|
// CollectiveBuilder that transposed epilogue below is used for sm90 gmma RS TT kernels
|
|
// since swapping NNN kernels input matrix and transposing its output at the same time then
|
|
// we can get TTN kernel.
|
|
template <
|
|
class TileShape_MNK,
|
|
class ClusterShape_MNK,
|
|
class EpilogueTileType,
|
|
class ElementAccumulator,
|
|
class ElementCompute,
|
|
class ElementC_,
|
|
class GmemLayoutTagC_,
|
|
int AlignmentC,
|
|
class ElementD,
|
|
class GmemLayoutTagD,
|
|
int AlignmentD,
|
|
FloatRoundStyle RoundStyle
|
|
>
|
|
struct CollectiveBuilder<
|
|
arch::Sm90,
|
|
arch::OpClassTensorOp,
|
|
TileShape_MNK,
|
|
ClusterShape_MNK,
|
|
EpilogueTileType,
|
|
ElementAccumulator,
|
|
ElementCompute,
|
|
ElementC_,
|
|
GmemLayoutTagC_,
|
|
AlignmentC,
|
|
ElementD,
|
|
GmemLayoutTagD,
|
|
AlignmentD,
|
|
cutlass::gemm::EpilogueTransposed,
|
|
fusion::LinearCombination<ElementD,ElementCompute,ElementC_,ElementCompute,RoundStyle>,
|
|
void> {
|
|
// Passing void C disables source load
|
|
using ElementC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
|
ElementD, ElementC_>; // prevents cute breakages
|
|
using GmemLayoutTagC = cute::conditional_t<cute::is_void_v<ElementC_>,
|
|
GmemLayoutTagD, GmemLayoutTagC_>;
|
|
static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v<ElementC_> ?
|
|
thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default;
|
|
|
|
static constexpr int FragmentSize = 1;
|
|
using ThreadOp = thread::LinearCombination<
|
|
ElementD, FragmentSize, ElementAccumulator, ElementCompute,
|
|
ScaleType, RoundStyle, ElementC>;
|
|
|
|
using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
|
|
cutlass::epilogue::collective::DefaultEpilogue<
|
|
cutlass::detail::TagToStrideC_t<GmemLayoutTagC>,
|
|
cutlass::detail::TagToStrideC_t<GmemLayoutTagD>,
|
|
ThreadOp,
|
|
cutlass::gemm::EpilogueTransposed>
|
|
>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cutlass::epilogue::collective
|