[Kernel] Update CUTLASS to 3.5.1 (#7085)
This commit is contained in:
parent
997cf78308
commit
8571ac4672
@ -193,8 +193,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# CUTLASS 3.5.0
|
||||
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
|
||||
# CUTLASS 3.5.1
|
||||
GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
|
||||
# Shallow clone with depth 1
|
||||
GIT_SHALLOW TRUE
|
||||
GIT_PROGRESS TRUE
|
||||
@ -237,7 +237,7 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
|
||||
@ -64,8 +64,6 @@ using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
|
||||
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
@ -73,14 +71,12 @@ template<
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcast {
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
|
||||
struct SharedStorage {
|
||||
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params),
|
||||
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element* smem_row;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
|
||||
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
|
||||
}
|
||||
|
||||
template <int EpiTiles, class GTensor, class STensor>
|
||||
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
|
||||
: gRow(cute::forward<GTensor>(gRow)),
|
||||
sRow(cute::forward<STensor>(sRow)),
|
||||
params(params) {}
|
||||
|
||||
GTensor gRow; // (CTA_M,CTA_N)
|
||||
STensor sRow; // (CTA_M,CTA_N,PIPE)
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
|
||||
if (!params.row_broadcast) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (issue_tma_load) {
|
||||
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
|
||||
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
|
||||
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
|
||||
// Issue the TMA bulk copy
|
||||
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||
|
||||
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
|
||||
cute::move(gRow), cute::move(sRow), params);
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <int EpiTiles, class RTensor, class STensor>
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
|
||||
: tCrRow(cute::forward<RTensor>(tCrRow)),
|
||||
tCsRow(cute::forward<STensor>(tCsRow)),
|
||||
params(params) {}
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, params(params_) {}
|
||||
|
||||
RTensor tCrRow; // (CPY,CPY_M,CPY_N)
|
||||
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tCrRow, *(params.ptr_row));
|
||||
fill(tSR_rRow, *(params.ptr_row));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM,
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
|
||||
cute::move(tCrRow), cute::move(tCsRow), params);
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
|
||||
@ -10,8 +10,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
@ -301,12 +299,14 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.get(), stream);
|
||||
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
|
||||
@ -18,8 +18,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleBDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, float>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
||||
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
};
|
||||
|
||||
/*
|
||||
@ -154,12 +148,8 @@ struct ScaledEpilogueBias
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using BiasDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, ElementD>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
|
||||
|
||||
public:
|
||||
@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user