Refactor masking in fwd pass into 1 object

This commit is contained in:
Tri Dao 2024-01-20 17:39:53 -08:00
parent ed4959b2eb
commit 66a127aef8
5 changed files with 169 additions and 147 deletions

View File

@ -25,9 +25,8 @@ struct Dropout {
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
// tensor has shape (8, MMA_M, MMA_N / 2)
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));

View File

@ -527,16 +527,16 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
static_assert(MMA_N_SdP % 2 == 0);
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
scores, block_row_idx, block_col_idx, AtomLayoutMS
acc_s, block_row_idx, block_col_idx, AtomLayoutMS
);
}
// Convert scores from fp32 to fp16/bf16
Tensor rP = !Is_dropout
? flash::convert_type<Element>(scores)
: flash::convert_type_relu<Element>(scores);
// Reshape rP from (nrow=(2, MMA_N), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_N, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_N, MMA_N) if using m16n8k8.
Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
? flash::convert_type<Element>(acc_s)
: flash::convert_type_relu<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
// if using m16n8k16 or (4, MMA_N, MMA_N) if using m16n8k8.
Tensor tPrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMmaSdP>(rP.layout()));
Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
// if (cute::thread0()) { print(tPaP); }

View File

@ -265,8 +265,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
flash::Softmax<2 * size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
@ -304,43 +304,9 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread0()) { print_tensor(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
}
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
// static_assert(decltype(size<0>(taccScS))::value == 4);
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
// Tensor idx_rowcol = make_tensor(taccScS.data(), flash::convert_layout_acc_rowcol(taccScS.layout()));
// flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM);
// Idk why it's get<1> and not get<0> of the stride.
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
// I can't get the stride from idx_row
flash::apply_mask_local</*HasWSLeft=*/Is_local>(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
// m_block * kBlockM + get<0>(idx_row(0)),
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
);
// if (cute::thread0()) { print_tensor(scores); }
}
mask.template apply_mask<Is_causal, Is_even_MN>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
);
flash::cp_async_wait<0>();
__syncthreads();
@ -358,26 +324,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Convert acc_s from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(acc_s);
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
Tensor rP_drop = make_fragment_like(rP);
cute::copy(rP, rP_drop);
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
rP_drop, block_row_idx, block_col_idx, kNWarps
);
cute::copy(acc_s_f16, tSgS);
cute::copy(rP_drop, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if (Is_dropout) {
dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
}
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
// if (cute::thread0()) { print(tOrP); }
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
@ -416,44 +382,31 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
cute::cp_async_fence();
}
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
mask.template apply_mask</*Causal_mask=*/false>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
Tensor rP = flash::convert_type<Element>(acc_s);
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
int block_col_idx = n_block * (kBlockN / 32);
if (Return_softmax) {
Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
Tensor rP_drop = make_fragment_like(rP);
cute::copy(rP, rP_drop);
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
acc_s_f16_drop, block_row_idx, block_col_idx, kNWarps
rP_drop, block_row_idx, block_col_idx, kNWarps
);
cute::copy(acc_s_f16, tSgS);
cute::copy(rP_drop, tSgS);
tSgS.data() = tSgS.data() + (-kBlockN);
}
if (Is_dropout) {
dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
}
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}
@ -845,7 +798,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
flash::Softmax<2 * size<1>(acc_o)> softmax;
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
@ -883,27 +836,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
}
// if (cute::thread0()) { print(scores); }
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
// can produce Inf / NaN.
if (!Is_causal && !Is_local) {
if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
} else {
flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
);
flash::cp_async_wait<0>();
__syncthreads();
@ -925,14 +860,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
// Convert scores from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
// Convert acc_s from fp32 to fp16/bf16
Tensor rP = flash::convert_type<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
// if (cute::thread0()) { print(scores); }
// This check is at the end of the loop since we always have at least 1 iteration
if (n_masking_steps > 1 && n_block <= n_block_min) {
@ -968,28 +902,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
cute::cp_async_fence();
}
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16);
}
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
flash::apply_mask_local(
scores, n_block * kBlockN, binfo.actual_seqlen_k,
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
binfo.actual_seqlen_q, kNWarps * 16,
params.window_size_left, params.window_size_right
);
}
mask.template apply_mask</*Causal_mask=*/false>(
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
);
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(scores);
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
Tensor rP = flash::convert_type<Element>(acc_s);
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
}

View File

@ -107,4 +107,107 @@ __forceinline__ __device__ void apply_mask_causal_w_idx(
}
}
template <bool Is_causal, bool Is_local, bool Has_alibi>
struct Mask {
const int max_seqlen_k, max_seqlen_q;
const int window_size_left, window_size_right;
const float alibi_slope;
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
const int window_size_left, const int window_size_right,
const float alibi_slope=0.f)
: max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q)
, window_size_left(window_size_left)
, window_size_right(window_size_right)
, alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
};
// Causal_mask: whether this particular iteration needs causal masking
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
static_assert(Layout::rank == 3, "Only support 3D Tensor");
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
// if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
if constexpr (Need_masking) {
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout()));
// Do we need both row and column indices, or just column incides?
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Col_idx_only) {
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// No causal, no local
if constexpr (Has_alibi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
if constexpr (!Is_even_MN) {
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
}
}
}
}
} else {
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if constexpr (Has_alibi) {
if constexpr (Is_causal) {
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
} else {
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
if constexpr (Causal_mask) {
if (col_idx >= col_idx_limit_right) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
if constexpr (Is_local) {
if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
// Causal and Local already handles MN masking
if (col_idx >= max_seqlen_k) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
}
}
}
};
};
} // namespace flash

View File

@ -193,34 +193,33 @@ __forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
__forceinline__ __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get<0, 1>(l),
get<1, 1, 1>(l));
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
template<typename Layout>
__forceinline__ __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
using X = Underscore;
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<2>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
get<0, 1>(l),
get<1, 1, 1>(l));
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
};
////////////////////////////////////////////////////////////////////////////////////////////////////