/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
  \brief Collective epilogue for SM90 pointer-array and grouped GEMMs, using TMA with warp specialization.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/epilogue/fusion/callbacks.hpp"
#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp"
#include "cutlass/detail/collective.hpp"
#include "cutlass/detail/layout.hpp"
#include "cutlass/trace.h"
#include "cutlass/cuda_host_adapter.hpp"

#include "cute/tensor.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace epilogue {
namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
  int StagesC_,
  int StagesD_,
  int FragmentSize_,
  bool ReuseSmemC_,
  bool DelayTmaStore_,
  int NumEpilogueWarpGroups_,
  class CtaTileMNK_,     //     (CTA_M,CTA_N,CTA_K)
  class EpilogueTile_,   // (EPI_TILE_M,EPI_TILE_N)
  class ElementC_,
  class StrideC_,
  class ElementD_,
  class StrideD_,
  class FusionCallbacks_,
  class CopyOpG2S_,
  class SmemLayoutAtomC_,
  class CopyOpS2R_,
  class CopyOpS2G_,
  class SmemLayoutAtomD_,
  class CopyOpR2S_,
  class CopyAtomC_
>
class CollectiveEpilogue<
    Sm90PtrArrayTmaWarpSpecialized<StagesC_,
                                   StagesD_,
                                   FragmentSize_,
                                   ReuseSmemC_,
                                   DelayTmaStore_,
                                   NumEpilogueWarpGroups_
                                  >,
    CtaTileMNK_,
    EpilogueTile_,
    ElementC_,
    StrideC_,
    ElementD_,
    StrideD_,
    FusionCallbacks_,
    CopyOpG2S_,
    SmemLayoutAtomC_,
    CopyOpS2R_,
    CopyOpS2G_,
    SmemLayoutAtomD_,
    CopyOpR2S_,
    CopyAtomC_
> {
public:
  //
  // Type Aliases
  //
  using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized<StagesC_,
                                                        StagesD_,
                                                        FragmentSize_,
                                                        ReuseSmemC_,
                                                        DelayTmaStore_,
                                                        NumEpilogueWarpGroups_
                                                       >;
  using CtaTileMNK = CtaTileMNK_;
  using EpilogueTile = EpilogueTile_;
  using FusionCallbacks = FusionCallbacks_;
  using ElementC = ElementC_;
  using StrideC = StrideC_;
  using InternalStrideC = cute::remove_pointer_t<StrideC>;
  using ElementD = ElementD_;
  using StrideD = StrideD_;
  using InternalStrideD = cute::remove_pointer_t<StrideD>;
  using CopyOpG2S = CopyOpG2S_;
  using SmemLayoutAtomC = SmemLayoutAtomC_;
  using CopyOpS2R = CopyOpS2R_;
  using CopyOpS2G = CopyOpS2G_;
  using SmemLayoutAtomD = SmemLayoutAtomD_;
  using CopyOpR2S = CopyOpR2S_;
  using CopyAtomC = CopyAtomC_;

  using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::Operation;
  using GmemTiledCopyC = CopyOpG2S;
  using GmemTiledCopyD = CopyOpS2G;

  static_assert(!is_layout<EpilogueTile>::value && is_tuple<EpilogueTile>::value, "EpilogueTile must be a cute::Tile or cute::Shape");
  static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]");
  static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]");
  static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M");
  static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N");
  static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]");
  static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]");

private:
  constexpr static bool is_source_supported = not cute::is_void_v<ElementC>;
  constexpr static bool is_destination_supported = not cute::is_void_v<ElementD>;

  using NonVoidElementD = cute::conditional_t<not is_destination_supported, fusion::get_element_aux_t<FusionCallbacks>, ElementD>;
  static_assert(not cute::is_void_v<NonVoidElementD>, "SmemElementD is void");
  using NonVoidElementC = cute::conditional_t<not is_source_supported, NonVoidElementD, ElementC>; // prevents void ref breakages

  using SmemElementC = typename cutlass::detail::get_unpacked_element_type<NonVoidElementC>::type;
  using SmemElementD = typename cutlass::detail::get_unpacked_element_type<NonVoidElementD>::type;

  constexpr static int StagesC = StagesC_;
  constexpr static int StagesD = StagesD_;
  constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported;
  constexpr static bool DelayTmaStore = DelayTmaStore_;

  constexpr static bool is_m_major_C = detail::is_m_major<InternalStrideC>();
  constexpr static bool is_m_major_D = detail::is_m_major<InternalStrideD>();

  constexpr static bool is_im2col_C = cute::is_same_v<CopyOpG2S, SM90_TMA_LOAD_IM2COL>;
  constexpr static bool is_im2col_D = cute::is_same_v<CopyOpS2G, SM90_TMA_STORE_IM2COL>;

  using SmemLayoutC = decltype(tile_to_shape(
      SmemLayoutAtomC{},
      make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int<StagesC>{}),
      cute::conditional_t<is_m_major_C, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));
  using SmemLayoutD = decltype(tile_to_shape(
      SmemLayoutAtomD{},
      make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int<ReuseSmemC ? StagesC : StagesD>{}),
      cute::conditional_t<is_m_major_D, Step<_2,_1,_3>, Step<_1,_2,_3>>{} ));

  constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC
                                          && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{}));
  static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met");
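
  // Example (illustrative): with EpilogueTile = (_64,_32), StagesC = 4, and StagesD = 2, every C stage
  // and every D stage spans 64 * 32 = 2048 smem elements, so the per-stage cosizes match and
  // StagesD <= StagesC holds; ReuseSmemC may then be enabled, in which case SmemLayoutD above is built
  // with StagesC stages so that D stores cycle through the same buffers that C loads vacate.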

  constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});
  constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{});
  constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD);

  using SmemArrayTypeC = cute::ArrayEngine<SmemElementC, cosize_v<SmemLayoutC>>;
  using SmemArrayTypeD = cute::ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>>;

  using EmptyType = cute::tuple<>;
  using SmemCStorage = cute::conditional_t<is_source_supported and (not ReuseSmemC),
                         SmemArrayTypeC,
                         EmptyType>;
  using SmemDStorage = cute::conditional_t<is_destination_supported,
                         SmemArrayTypeD,
                         EmptyType>;

  struct CollectiveStorageWithC {
    alignas(SmemAlignmentC) ArrayEngine<SmemElementC, cosize_v<SmemLayoutC>> smem_C;
    alignas(SmemAlignmentD) ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>> smem_D;
  };

  union CollectiveStorageWithoutC {
    cute::array<SmemElementC, 0> smem_C;
    alignas(SmemAlignmentD) ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>> smem_D;
  };

  union CollectiveStorageReuseC {
    alignas(MaxSmemAlignment) ArrayEngine<SmemElementC, cosize_v<SmemLayoutC>> smem_C;
    alignas(MaxSmemAlignment) ArrayEngine<SmemElementD, cosize_v<SmemLayoutD>> smem_D;
  };

public:
  // TMA pipeline for loading C
  using LoadPipeline = cutlass::PipelineTransactionAsync<StagesC>;
  using LoadPipelineState = cutlass::PipelineState<StagesC>;
  constexpr static uint32_t TmaTransactionBytes =
    (size(take<0,2>(SmemLayoutC{})) * static_cast<uint32_t>(sizeof_bits<SmemElementC>::value)) / 8;
  constexpr static bool RequiresTransactionBytes = true;
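
  // Example (illustrative): with EpilogueTile = (_64,_32) and a 16-bit SmemElementC such as half_t,
  // each C stage holds 64 * 32 = 2048 elements, so TmaTransactionBytes = 2048 * 16 / 8 = 4096 bytes;
  // every producer TMA load of one epilogue subtile expects a 4 KiB transaction on its stage barrier.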

  constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_;

  // TMA pipeline for storing D
  using StorePipeline = cute::conditional_t<ReuseSmemC,
                          cutlass::PipelineTmaStore<StagesC, StagesD-1>,
                          cutlass::PipelineTmaStore<StagesD>>;
  using StorePipelineState = cutlass::PipelineState<ReuseSmemC ? StagesC : StagesD>;

  struct SharedStorage {
    struct TensorStorage {
      using CollectiveStorage = cute::conditional_t<not is_source_supported, CollectiveStorageWithoutC,
                                  cute::conditional_t<ReuseSmemC, CollectiveStorageReuseC, CollectiveStorageWithC>>;
      CollectiveStorage collective;

      using FusionStorage = typename FusionCallbacks::SharedStorage;
      FusionStorage thread;
    } tensors;

    struct TensorMapStorage : cute::aligned_struct<128> {
      cute::TmaDescriptor smem_tensormap_C;
      cute::array<cute::TmaDescriptor, NumEpilogueWarpGroups> smem_tensormap_D;
    } tensormaps;

    using PipelineStorage = typename LoadPipeline::SharedStorage;
    PipelineStorage pipeline;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;
  using TensorMapStorage = typename SharedStorage::TensorMapStorage;
  using PipelineStorage = typename SharedStorage::PipelineStorage;

  static constexpr bool IsGroupedGemmKernel = !cute::is_same_v<InternalStrideC, StrideC>;

  // Host side epilogue arguments
  struct Arguments {
    typename FusionCallbacks::Arguments thread{};
    ElementC const** ptr_C = nullptr;
    StrideC dC;
    ElementD** ptr_D = nullptr;
    StrideD dD;
  };
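
  // Example (illustrative; variable names are hypothetical and the fusion arguments depend on the
  // FusionCallbacks in use): a pointer-array GEMM passes device arrays of per-batch pointers rather
  // than the pointers themselves, e.g.
  //
  //   ElementC const** d_ptr_C;  // device array: d_ptr_C[b] points to batch b's C tensor
  //   ElementD**       d_ptr_D;  // device array: d_ptr_D[b] points to batch b's D tensor
  //   Arguments epi_args{fusion_args, d_ptr_C, stride_C, d_ptr_D, stride_D};
  //
  // For grouped GEMM, StrideC/StrideD are themselves pointer types (see InternalStrideC/D above), so
  // dC and dD point to per-group stride arrays.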

  // Device side epilogue params
  struct Params {
    using TMA_C = decltype(make_tma_copy(
        CopyOpG2S{},
        make_tensor(make_gmem_ptr(static_cast<NonVoidElementC const*>(nullptr)),
            repeat_like(InternalStrideC{}, int32_t(0)), InternalStrideC{}),
        take<0,2>(SmemLayoutC{}),
        EpilogueTile{},
        _1{}));

    using TMA_D = decltype(make_tma_copy(
        CopyOpS2G{},
        make_tensor(make_gmem_ptr(static_cast<NonVoidElementD const*>(nullptr)),
            repeat_like(InternalStrideD{}, int32_t(0)), InternalStrideD{}),
        take<0,2>(SmemLayoutD{}),
        EpilogueTile{},
        _1{}));

    typename FusionCallbacks::Params thread{};
    TMA_C tma_load_c;
    TMA_D tma_store_d;
    cute::TmaDescriptor* tensormaps;
    ElementC const** ptr_C;
    StrideC dC;
    ElementD** ptr_D;
    StrideD dD;
    uint32_t tma_transaction_bytes = TmaTransactionBytes;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params
  to_underlying_arguments(
      ProblemShape const& problem_shape,
      Arguments const& args,
      [[maybe_unused]] void* workspace) {
    // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc.
    // These will be replaced with correct values before the initial tma load.
    auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1));
    auto init_M = get<0>(init_shape);
    auto init_N = get<1>(init_shape);
    auto init_L = get<3>(init_shape);

    static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D");

    InternalStrideC stride_c;
    InternalStrideD stride_d;
    if constexpr (IsGroupedGemmKernel) {
      // Strides for Grouped Gemm will be replaced prior to the first access regardless.
      stride_c = InternalStrideC{};
      stride_d = InternalStrideD{};
    }
    else {
      // Tensor shapes for Ptr-Array are initialized correctly only here.
      auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1);
      init_M = get<0>(problem_shape_MNKL);
      init_N = get<1>(problem_shape_MNKL);
      init_L = get<3>(problem_shape_MNKL);

      stride_c = args.dC;
      stride_d = args.dD;
    }

    uint32_t transaction_bytes = TmaTransactionBytes;
    typename Params::TMA_C tma_load_c = {};
    if constexpr (is_source_supported) {
      ElementC const* ptr_C_first_batch = reinterpret_cast<ElementC const*>(args.ptr_C);
      Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{})));
      tma_load_c = make_tma_copy(
          CopyOpG2S{},
          tensor_c,
          take<0,2>(SmemLayoutC{}),
          EpilogueTile{},
          _1{});
    }

    typename Params::TMA_D tma_store_d;
    if constexpr (is_destination_supported) {
      ElementD const* ptr_D_first_batch = reinterpret_cast<ElementD const*>(args.ptr_D);
      Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{})));
      tma_store_d = make_tma_copy(
          CopyOpS2G{},
          tensor_d,
          take<0,2>(SmemLayoutD{}),
          EpilogueTile{},
          _1{});
    }

    auto fusion_workspace = static_cast<char*>(workspace);
    auto fusion_workspace_size = FusionCallbacks::get_workspace_size(problem_shape, args.thread);
    auto tma_descriptor_workspace = reinterpret_cast<cute::TmaDescriptor*>(
        static_cast<char*>(workspace) + fusion_workspace_size);

    return {
      FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace),
      tma_load_c,
      tma_store_d,
      tma_descriptor_workspace,
      args.ptr_C,
      args.dC,
      args.ptr_D,
      args.dD,
      transaction_bytes,
    };
  }

  template <class ProblemShape>
  static size_t
  get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) {
    constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
    auto descriptors_shape = cute::make_shape(sm_count, Int<NumInputTensors>{});
    constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor);
    // Allocate gmem space for one set of tensormaps per SM: the per-warp-group D tensormap copies,
    // followed by the C tensormap copy when C is non-void.
    return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread);
  }
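
  // Example (illustrative): with one epilogue warp group and a non-void C, NumInputTensors = 2; on a
  // 132-SM device the tensormap workspace is 132 * 2 * sizeof(cute::TmaDescriptor) = 132 * 2 * 128 bytes
  // = 33 KiB, placed after the FusionCallbacks workspace within the same allocation
  // (see to_underlying_arguments above).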

  template <class ProblemShape>
  static cutlass::Status
  initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream,
      CudaHostAdapter* cuda_adapter = nullptr) {
    return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter);
  }

  template <class ProblemShape>
  static bool
  can_implement(
      ProblemShape problem_shape,
      [[maybe_unused]] Arguments const& args) {

    bool implementable = true;
    bool fusion_implementable = true;

    if (problem_shape.is_host_problem_shape_available()) {
      for (int i = 0; i < problem_shape.groups(); ++i) {
        auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
        auto [M,N,K,L] = problem_shape_MNKL;

        if constexpr (is_destination_supported) {
          constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits<ElementD>();
          constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits<ElementD>::value;
          implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_D>(cute::make_shape(M,N,L), InternalStrideD{});
        }

        if constexpr (not cute::is_void_v<ElementC>) {
          constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits<ElementC>();
          constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits<ElementC>::value;
          implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_C>(cute::make_shape(M,N,L), InternalStrideC{});
        }

        fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread);
      }
    }
    else {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Skipping alignment and fusion checks because the host problem shape is not available.\n");
    }

    if (!implementable) {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
    }

    if (!fusion_implementable) {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n");
    }

    bool beta_implementable = true;

    if constexpr (cute::is_void_v<ElementC>) {
      if constexpr (detail::has_beta<Arguments>::value) {
        beta_implementable = args.thread.beta == 0.0;
      }
      if constexpr (detail::has_beta_ptr<Arguments>::value) {
        beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr;
      }
    }

    if (!beta_implementable) {
      CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n");
    }

    return implementable && fusion_implementable && beta_implementable;
  }

  template<class TileShapeMNK>
  CUTLASS_HOST_DEVICE
  static constexpr int
  get_load_pipe_increment(TileShapeMNK tile_shape_MNK) {
    // Compute number of epilogue subtiles
    return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{}));
  }

  template<class TileShapeMNK>
  CUTLASS_HOST_DEVICE
  static constexpr int
  get_store_pipe_increment(TileShapeMNK tile_shape_MNK) {
    return get_load_pipe_increment(tile_shape_MNK);
  }
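
  // Example (illustrative): for CtaTileMNK = (_128,_128,_64) and EpilogueTile = (_64,_32), each CTA tile
  // is covered by (128/64) * (128/32) = 8 epilogue subtiles, so both the load and store pipelines advance
  // by 8 states per output tile.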

  CUTLASS_HOST_DEVICE
  CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors)
      : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {}

  CUTLASS_DEVICE
  bool
  is_producer_load_needed() const {
    return fusion_callbacks.is_producer_load_needed();
  }

  CUTLASS_DEVICE auto
  load_init(
      Params const& params,
      TensorMapStorage& shared_tensormaps,
      int32_t sm_count,
      int32_t sm_idx) {
    // Initialize tma for loading
    constexpr bool IsLoad = true;
    auto load_tensormaps = tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, 0);
    return load_tensormaps;
  }

  template<
    class ProblemShapeMNKL,
    class TileShapeMNK,
    class TileCoordMNKL,
    class TiledMma,
    class TensorMapC,
    __CUTE_REQUIRES(std::is_pointer_v<TensorMapC>)
  >
  CUTLASS_DEVICE auto
  load(
      LoadPipeline load_pipeline,
      LoadPipelineState load_pipe_producer_state,
      ProblemShapeMNKL problem_shape_mnkl,
      TileShapeMNK tile_shape_MNK,
      TileCoordMNKL tile_coord_mnkl,
      TiledMma tiled_mma,
      int thread_idx,
      TensorStorage& shared_tensors,
      TensorMapC const& load_tensormap,
      int subtile_idx=-1,
      bool wait_until_load_finishes = false) {
    using namespace cute;

    // Indexing variables
    auto [M, N, K, L] = problem_shape_mnkl;
    auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;

    static_assert(!is_im2col_D, "Do not support im2col");

    auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{});

    // Represent the full source tensor, slice to get the tile this CTA is currently responsible for
    Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{}));               // (M,N,L)
    Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{}));
    Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape);                                    // (CTA_M,CTA_N)

    // Apply epilogue subtile, get matching smem tensor
    auto ptr_sC = shared_tensors.collective.smem_C.begin();
    Tensor gC_epi = flat_divide(gC, EpilogueTile{});                             // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
    Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{});           // (EPI_TILE_M,EPI_TILE_N,PIPE_C)

    // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_)
    ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{});
    Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi);                              // (G2S,G2S_M,G2S_N,EPI_M,EPI_N)
    Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi);                              // (G2S,G2S_M,G2S_N,PIPE_C)

    // Get the fusion callbacks for the producer load warp
    auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{
      problem_shape_mnkl,
      CtaTileMNK{},
      tile_coord_mnkl,
      tiled_mma,
      EpilogueTile{},
      thread_idx
    };
    auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args);
    bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

    LoadPipelineState last_load_producer_state = load_pipe_producer_state;

    // Predication for TMA load (one thread issues TMA load)
    bool issue_tma_load = cute::elect_one_sync();

    // Acquire the lock for the first stage
    load_pipeline.producer_acquire(load_pipe_producer_state);
    uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);

    // Pre-loop fusion callback entry point
    pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load);

    LoadPipelineState prior_state = load_pipe_producer_state;

    bool did_load = false;

    CUTLASS_PRAGMA_UNROLL
    for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) {
        if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gC_epi)) + epi_m) != subtile_idx) {
          continue;
        }
        // Acquire the lock for this stage
        constexpr uint16_t mcast_mask = 0;
        uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state);
        load_pipeline.producer_acquire(load_pipe_producer_state);

        // Loop fusion callback entry point
        pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load);

        // Execute the TMA load for C if needed
        if (is_C_load_needed) {
          if (issue_tma_load) {
            copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask),
                 bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index()));
            load_pipeline.producer_expect_transaction(load_pipe_producer_state);
          }
          last_load_producer_state = load_pipe_producer_state;
          did_load = true;
        }

        // Commit TMA loads for this stage and release the lock
        load_pipeline.producer_commit(load_pipe_producer_state);
        ++load_pipe_producer_state;
      }
    }

    // Post-loop fusion callback entry point
    pld_callbacks.end();

    if (wait_until_load_finishes && did_load) {
      typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state =
        {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()};
      load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state);
    }

    return load_pipe_producer_state;
  }

  CUTLASS_DEVICE auto
  load_tail(
      LoadPipeline load_pipeline,
      LoadPipelineState load_pipe_producer_state) {

    if (!fusion_callbacks.is_producer_load_needed()) {
      return load_pipe_producer_state;
    }

    bool issue_tma_load = cute::elect_one_sync();
    if (issue_tma_load) {
      load_pipeline.producer_tail(load_pipe_producer_state);
    }

    return load_pipe_producer_state;
  }

  template<
    class ProblemShapeMNKL,
    class TileShapeMNK,
    class TileCoordMNKL,
    class AccEngine, class AccLayout,
    class TiledMma,
    class TensorMapD
  >
  CUTLASS_DEVICE auto
  store(
      LoadPipeline load_pipeline,
      LoadPipelineState load_pipe_consumer_state,
      StorePipeline store_pipeline,
      StorePipelineState store_pipe_producer_state,
      ProblemShapeMNKL problem_shape_mnkl,
      TileShapeMNK tile_shape_MNK,
      TileCoordMNKL tile_coord_mnkl,
      cute::Tensor<AccEngine,AccLayout> accumulators,
      TiledMma tiled_mma,
      int thread_idx,
      TensorStorage& shared_tensors,
      TensorMapD const& store_tensormap,
      int subtile_idx=-1) {

    using namespace cute;
    using ElementAccumulator = typename AccEngine::value_type;
    using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits<FusionCallbacks>::ElementCompute;
    using ElementCompute = cute::conditional_t<cute::is_void_v<ElementCompute_>,ElementAccumulator,ElementCompute_>;

    static_assert(is_rmem<AccEngine>::value, "Accumulator must be RF resident.");
    static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)");
    static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
    static_assert(is_static<TileShapeMNK>::value, "TileShapeMNK must be static");
    static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3");
    static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4");

    // Indexing variables
    auto [M, N, K, L] = problem_shape_mnkl;
    auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl;

    static_assert(!is_im2col_D, "Do not support im2col");

    auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{});

    // Represent the full output tensor, slice to get the tile this CTA is responsible for
    Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{}));              // (M,N,L)
    Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{}));
    Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape);                                    // (CTA_M,CTA_N)

    // Apply epilogue subtiling
    Tensor gD_epi = flat_divide(gD, EpilogueTile{});                             // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)

    // Construct the corresponding pipelined smem tensors
    auto ptr_sC = shared_tensors.collective.smem_C.begin();
    auto ptr_sD = shared_tensors.collective.smem_D.begin();
    Tensor sC_epi = cute::as_position_independent_swizzle_tensor(
                      make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}));        // (EPI_TILE_M,EPI_TILE_N,PIPE_C)
    Tensor sD_epi = cute::as_position_independent_swizzle_tensor(
                      make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{}));        // (EPI_TILE_M,EPI_TILE_N,PIPE_D)

    TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma);

    // (t)hread-partition for (r)egister to (s)mem copy (tRS_)
    TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom<CopyOpR2S,SmemElementD>{}, tiled_copy_C_atom);
    ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx);
    Tensor tRS_rAcc = thread_r2s.retile_S(accumulators);                         // ((R2S,R2S_V),MMA_M,MMA_N)
    Tensor tRS_sD = thread_r2s.partition_D(sD_epi);                              // (R2S,R2S_M,R2S_N,PIPE_D)

    auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc);
    auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc);
    auto epi_tile_m = size<0>(EpilogueTile{});
    auto epi_tile_n = size<1>(EpilogueTile{});

    // Allocate D registers
    Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi))));
    Tensor tRS_rD = make_tensor<SmemElementD>(tRS_rD_layout);                    // (R2S,R2S_M,R2S_N)

    // Vectorized fragment view
    constexpr int FragmentSize = DispatchPolicy::FragmentSize;
    Tensor tRS_rAcc_frg = recast<Array<ElementAccumulator, FragmentSize>>(tRS_rAcc);
    Tensor tRS_rD_frg   = recast<Array<SmemElementD      , FragmentSize>>(tRS_rD);
    CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly");

    // (t)hread-partition for (s)mem to (r)egister copy (tSR_)
    TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom<CopyOpS2R, SmemElementC>{}, tiled_copy_C_atom);
    ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx);
    Tensor tSR_sC = thread_s2r.partition_S(sC_epi);                              // (S2R,S2R_M,S2R_N,PIPE_C)
    Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout();                 // (S2R,S2R_M,S2R_N)

    // Allocate C registers
    // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type
    // to eliminate some redundant pack+unpack instruction sequences for sub-word types
    constexpr bool IsDirectS2R = cute::is_same_v<CopyOpS2R, AutoVectorizingCopyWithAssumedAlignment<128>>
                                && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1;
    using RegisterElementC = cute::conditional_t<IsDirectS2R, ElementCompute, SmemElementC>;
    Tensor tRS_rC = make_tensor<RegisterElementC>(tRS_rD_layout);                // (R2S,R2S_M,R2S_N)
    Tensor tSR_rC = thread_s2r.retile_D(tRS_rC);                                 // (S2R,S2R_M,S2R_N)

    // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_)
    ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{});
    Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi);                              // (S2G,S2G_M,S2G_N,PIPE_D)
    Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi);                              // (S2G,S2G_M,S2G_N,EPI_M,EPI_N)

    // OOB predication for tile quantization "residue"
    // Absolute coordinate tensors (dynamic)
    Tensor mD_crd = make_identity_tensor(make_shape(M,N));                                     // (M,N)
    Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord));  // (CTA_M,CTA_N)
    Tensor tRS_cD_mn = thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{}));             // (R2S,R2S_M,R2S_N,EPI_M,EPI_N)
    // Relative coordinate tensors (static)
    Tensor cD = make_counting_tensor(cD_mn.layout());                                          // (CTA_M,CTA_N)
    Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout());                                  // (R2S,R2S_M,R2S_N,EPI_M,EPI_N)
    // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate
    auto residue_cD = make_coord(M,N) - cD_mn(_0{});                                           // (m,n)
    auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{});                                   // (m,n)

    CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
    CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");

    // Get the fusion callbacks for the consumer store warps
    constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout
    auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{
      problem_shape_mnkl,
      CtaTileMNK{},
      tile_coord_mnkl,
      tiled_mma,
      EpilogueTile{},
      tiled_r2s,
      cD,
      residue_cD,
      tRS_cD,
      residue_tRS_cD,
      tRS_rC,
      thread_idx
    };
    auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks<RefSrc>(cst_args);
    bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed();
    bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed();

    // Thread synchronizer for previously issued waits or fences
    // to ensure visibility of smem reads/writes to threads or TMA unit
    auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };

    // Predication for TMA store (one warp issues TMA store)
    bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0;

    // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight.
    // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can
    // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks.
    // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion.
    LoadPipelineState load_wait_state = load_pipe_consumer_state;
    if constexpr (ReuseSmemC) {
      load_wait_state = store_pipe_producer_state;
      load_wait_state.phase_ ^= 1;
    }

    // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions
    // Sync requirements of smem reuse may preclude this optimization
    // Delayed stores cause delayed stage releases, which causes a deadlock when StagesC == StagesD
    int epi_m_prev = 0, epi_n_prev = 0;
    static_assert(not (DelayTmaStore and ReuseSmemC and StagesC == StagesD), "This TMA epilogue configuration will deadlock");

    // The TMA store sequence for one subtile iteration
    auto tma_store_fn = [&] (int epi_m, int epi_n) {
      // Write the tile from smem to gmem with TMA
      cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
      synchronize();                            // ensure all threads have issued their async fence
      if constexpr (is_destination_supported) {
        if (issue_tma_store) {
          copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n));
        }
      }

      // Post async fence, pre TMA commit callback entry point
      cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store);

      // Commit the TMA stores for this stage
      if (issue_tma_store) {
        store_pipeline.producer_commit(store_pipe_producer_state);
      }
      ++store_pipe_producer_state;
      ++issued_stores;

      // Wait for the next smem buffer to be available
      if (issue_tma_store) {
        store_pipeline.producer_acquire(store_pipe_producer_state);
      }
      synchronize();

      if constexpr (ReuseSmemC) {
        // producer_acquire returns when at most StagesD-1 committed stores are pending
        bool store_finished = issued_stores > StorePipeline::UnacquiredStages;
        // Let the producer (DMA) warp know the earliest smem buffer is consumed and empty after StagesD producer commits
        if (store_finished) {
          if (is_producer_load_needed) {
            load_pipeline.consumer_release(load_pipe_consumer_state);
          }
          ++load_pipe_consumer_state;
        }
      }
    };

    //
    // BEGIN EPILOGUE
    //

    // Pre-loop fusion callback entry point
    cst_callbacks.begin();

    // For each output tile
    CUTLASS_PRAGMA_UNROLL
    for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) {
      CUTLASS_PRAGMA_UNROLL
      for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) {
        bool is_first_iteration = epi_m == 0 && epi_n == 0;
        bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1;

        if (subtile_idx != -1 && (epi_n * static_cast<int>(size<2>(gD_epi)) + epi_m) != subtile_idx) {
          continue;
        }

        cst_callbacks.begin_loop(epi_m, epi_n);

        if (is_producer_load_needed) {
          // Wait for the producer load to fill smem
          load_pipeline.consumer_wait(load_wait_state);

          if (is_C_load_needed) {
            // Copy source tile from smem to register
            copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC);
          }
        }

        // First loop fusion callback entry point
        cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed);

        if (is_producer_load_needed) {
          if constexpr (not ReuseSmemC) {
            // Let producer load warp know smem buffers are consumed and empty
            cutlass::arch::fence_view_async_shared();
            load_pipeline.consumer_release(load_pipe_consumer_state);
            ++load_pipe_consumer_state;
          }
          ++load_wait_state;
        }

        int mma_m = epi_m;
        int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n;
        Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n);

        // Vectorized fragment loop with visitor callback entry point
        int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
        int r2s_v = epi_n_in_mma * size(tRS_rD_frg);
        CUTLASS_PRAGMA_UNROLL
        for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) {
          tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n);
        }

        // The latest we can delay the TMA store is right before the smem store of the next iteration
        // since the current TMA store needs to be committed before we can acquire the next smem buffer
        if constexpr (DelayTmaStore) {
          // Issue TMA stores for the previous subtile
          if (not is_first_iteration and subtile_idx == -1) {
            tma_store_fn(epi_m_prev, epi_n_prev);
          }
          epi_m_prev = epi_m;
          epi_n_prev = epi_n;
        }

        // Smem reduction callback entry point using current store buffer for workspace
        cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()),
                             synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg);

        // Copy tile from register to smem
        if constexpr (is_destination_supported) {
          copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index()));
        }

        // Post reduction, pre TMA store callback entry point
        constexpr bool issue_smem_store = true; // No smem store predication
        cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store);

        if constexpr (not DelayTmaStore) {
          // Issue TMA stores for this subtile
          tma_store_fn(epi_m, epi_n);
        }

        cst_callbacks.end_loop(epi_m, epi_n);

      } // for epi_m
    } // for epi_n

    if constexpr (DelayTmaStore) {
      // Issue TMA stores for the last subtile
      tma_store_fn(epi_m_prev, epi_n_prev);
    }

    // Post-loop fusion callback entry point
    cst_callbacks.end();

    return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state);
  }

  CUTLASS_DEVICE auto
  store_tail(
      LoadPipeline load_pipeline,
      LoadPipelineState load_pipe_consumer_state,
      StorePipeline store_pipeline,
      StorePipelineState store_pipe_producer_state) {
    // Wait for all TMA stores to complete
    store_pipeline.producer_tail(store_pipe_producer_state);
    // Reset store counter
    issued_stores = 0;

    if constexpr (ReuseSmemC) {
      if (fusion_callbacks.is_producer_load_needed()) {
        // Issue releases on up to StagesD-1 previously issued TMA stores
        constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{}));
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < release_stages; ++stage) {
          load_pipeline.consumer_release(load_pipe_consumer_state);
          ++load_pipe_consumer_state;
        }
      }
    }

    return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state);
  }

  CUTLASS_DEVICE auto
  store_init(
      Params const& params,
      TensorMapStorage& shared_tensormaps,
      int32_t sm_count,
      int32_t sm_idx,
      int32_t warp_group_idx) {
    int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup;
    // Since only one warp issues TMA store, we only need that one warp to initialize tensormaps
    if (warp_idx_in_warp_group == 0) {
      // Initialize tma
      constexpr bool IsLoad = false;
      auto store_tensormaps = tensormaps_init<IsLoad>(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx);
      return store_tensormaps;
    }
    TmaDescriptor* null_tma_desc = nullptr;
    return cute::make_tuple(null_tma_desc);
  }

  //
  // Methods to perform different parts of TMA/Tensormap modifications
  //

  template <bool IsLoad>
  CUTLASS_DEVICE auto
  tensormaps_init(
      Params const& params,
      TensorMapStorage& shared_tensormaps,
      int32_t sm_count,
      int32_t sm_idx,
      int32_t warp_group_idx) {

    constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v<ElementC> ? 0 : 1);
    Layout desc_layout = make_layout(make_shape(sm_count, Int<NumInputTensors>{}));

    Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout);         // (SMs, NumInputTensors)

    if constexpr (IsLoad) {
      if (not cute::is_void_v<ElementC>) {
        constexpr int C_tensormap_index = NumEpilogueWarpGroups;
        Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{});
        Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{});

        if (cute::elect_one_sync()) {
          // Bringing tensormaps from params to smem for modification later
          copy(recast<uint128_t>(pC_tensormap), recast<uint128_t>(sC_tensormap));
        }
        __syncwarp();
        return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index));
      }
      TmaDescriptor* null_tma_desc = nullptr;
      return cute::make_tuple(null_tma_desc);
    }
    else {
      Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{});
      Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{});

      if (cute::elect_one_sync()) {
        // Bringing tensormaps from params to smem for modification later
        copy(recast<uint128_t>(pD_tensormap), recast<uint128_t>(sD_tensormap));
      }
      __syncwarp();
      return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx));
    }
  }
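
  // Layout of the tensormap workspace indexed by tensormaps_init above (for reference): per SM,
  // gmem_tensormap(sm, wg) holds the D descriptor for epilogue warp group wg for wg in
  // [0, NumEpilogueWarpGroups), and the C descriptor, when C is non-void, sits at column
  // NumEpilogueWarpGroups (C_tensormap_index).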

  // Replace address for the global tensor (to be done by a single thread)
  template <bool IsLoad>
  CUTLASS_DEVICE
  void
  tensormaps_replace_global_address(
      TensorMapStorage& shared_tensormaps,
      Params const& params,
      int32_t next_batch,
      int32_t warp_group_idx) {
    // Replacing global_address for the next batch
    if constexpr (IsLoad) {
      if constexpr (is_source_supported) {
        cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C,
                                                        params.ptr_C[next_batch]);
      }
    }
    else if constexpr (is_destination_supported) {
      cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx],
                                                      params.ptr_D[next_batch]);
    }
  }

  // Replace dims and strides for the global tensor (used only for grouped GEMM; to be done by a single thread)
  template <bool IsLoad, class ProblemShape_MNKL>
  CUTLASS_DEVICE
  void
  tensormaps_replace_global_tensor_properties(
      TensorMapStorage& shared_tensormaps,
      Params const& params,
      int32_t next_group,
      ProblemShape_MNKL problem_shape_mnkl,
      int32_t warp_group_idx) {
    const uint32_t M = get<0>(problem_shape_mnkl);
    const uint32_t N = get<1>(problem_shape_mnkl);
    // Only consider dimensions and strides that we need to recalculate and replace for each group
    constexpr int TensorRank = rank(ProblemShape_MNKL{}) - 1; // rank of the (M,N,L) output tensor, i.e. the problem rank excluding K
    static_assert(TensorRank == Int<3>{},
      "Descriptor modification for global dims & strides expects a rank-3 tensor.");

    cute::array<uint32_t, TensorRank> prob_shape  = {1,1,1};
    cute::array<uint64_t, TensorRank> prob_stride = {0,0,0};

    if constexpr (IsLoad) {
      if constexpr (is_source_supported) {
        ElementC const* ptr_C = nullptr;
        Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group]));

        cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c,
                                                 prob_shape, prob_stride);
        // Convert strides to byte strides
        for (uint64_t& stride : prob_stride) {
          stride = (stride * sizeof_bits_v<ElementC>) / 8;
        }
        cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C,
                                                                prob_shape,
                                                                prob_stride);
      }
    }
    else if constexpr (is_destination_supported) {
      ElementD const* ptr_D = nullptr;

      // Dummy tensor: only its shape and strides are read by fill_tma_gmem_shape_stride below
      Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group]));

      cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d,
                                               prob_shape, prob_stride);
      // Convert strides to byte strides
      for (uint64_t& stride : prob_stride) {
        stride = (stride * sizeof_bits_v<ElementD>) / 8;
      }

      cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx],
                                                              prob_shape,
                                                              prob_stride);
    }
  }

  template <bool IsLoad, class ProblemShape_MNKL>
  CUTLASS_DEVICE
  void
  tensormaps_perform_update(
      TensorMapStorage& shared_tensormaps,
      Params const& params,
      cute::TmaDescriptor const* tensormap,
      ProblemShape_MNKL problem_shape_mnkl,
      int32_t next_batch,
      int32_t warp_group_idx) {
    if (cute::elect_one_sync()) {

      // Replacing global_address for the next batch
      tensormaps_replace_global_address<IsLoad>(shared_tensormaps, params, next_batch, warp_group_idx);

      if constexpr (IsGroupedGemmKernel) {
        // Replacing global dims and strides for the next batch
        tensormaps_replace_global_tensor_properties<IsLoad>(
          shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx);
      }
    }
  }

  template <bool IsLoad>
  CUTLASS_DEVICE
  void
  tensormaps_cp_fence_release(
      TensorMapStorage& shared_tensormaps,
      cute::TmaDescriptor const* tensormap,
      [[maybe_unused]] uint32_t lane_predicate,
      int32_t warp_group_idx = 0) {
    // The entire warp must execute this (i.e., it is a warp-aligned instruction)
    if constexpr (IsLoad) {
      if constexpr (is_source_supported) {
        tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C);
      }
    }
    else if constexpr (is_destination_supported) {
      tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]);
    }
  }

  template <bool IsLoad>
  CUTLASS_DEVICE
  void
  tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) {
    if constexpr (IsLoad) {
      if constexpr (not cute::is_void_v<ElementC>) {
        cute::tma_descriptor_fence_acquire(tensormap);
      }
    }
    else {
      cute::tma_descriptor_fence_acquire(tensormap);
    }
  }

private:
  Params const& params;
  FusionCallbacks fusion_callbacks;
  int issued_stores = 0;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////