Simplify SmemLayoutVtransposed in kernel_traits.h
This commit is contained in:
parent
c9861a032d
commit
8d1b169ed1
@ -975,9 +975,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
|
||||
// Layout p_l = tPrP.layout();
|
||||
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
|
||||
// flash::gemm_A_in_regs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
// flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
|
||||
// flash::gemm_A_in_regs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
// flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
|
||||
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
|
||||
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
|
||||
|
||||
@ -369,11 +369,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
@ -444,7 +444,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
}
|
||||
// if (cute::thread0()) { print(tOrP); }
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
@ -483,19 +483,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
|
||||
flash::apply_mask_local(
|
||||
scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
@ -528,7 +528,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
@ -977,11 +977,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
@ -1027,7 +1027,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
@ -1069,11 +1069,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
if (Has_alibi) {
|
||||
flash::apply_alibi<Is_causal>(
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
scores,
|
||||
n_block * kBlockN,
|
||||
binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q,
|
||||
binfo.actual_seqlen_q,
|
||||
kNWarps * 16,
|
||||
alibi_slope
|
||||
);
|
||||
@ -1094,7 +1094,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
|
||||
flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposedNoSwizzle{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
|
||||
using SmemLayoutVtransposed = decltype(
|
||||
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -247,19 +237,9 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
SmemLayoutAtomKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomKtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
|
||||
using SmemLayoutKtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// Maybe the KtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||
using SmemLayoutKtransposed = decltype(
|
||||
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
|
||||
|
||||
// TODO: generalize to other values of kBlockN
|
||||
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||
@ -277,30 +257,15 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdS{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kPBlockN>>>;
|
||||
using SmemLayoutAtomPdStransposed = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
|
||||
using SmemLayoutPdStransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposed{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposedNoSwizzle{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||
using SmemLayoutPdStransposed = decltype(
|
||||
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
|
||||
|
||||
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomQdOtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
|
||||
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||
using SmemLayoutQdOtransposed = decltype(
|
||||
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
|
||||
|
||||
using SmemLayoutAtomdKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
|
||||
@ -162,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy, typename ThrCopy>
|
||||
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
inline __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
@ -188,10 +188,7 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
|
||||
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
|
||||
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -207,13 +204,9 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
|
||||
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
|
||||
// TD [2023-08-13]: Same error as above on Cutlass 3.2
|
||||
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
// get<0, 1>(l),
|
||||
// get<1, 1, 1>(l));
|
||||
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
|
||||
get<1>(get<0>(l)),
|
||||
get<1>(get<1>(get<1>(l))));
|
||||
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
get<0, 1>(l),
|
||||
get<1, 1, 1>(l));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
Loading…
Reference in New Issue
Block a user