apply_dropout now takes tensor of rowcol layout
parent d9cbcfb41c
commit 10dad61277

apply_dropout now receives the scores tensor in its (nrow=(2, MMA_M), ncol=(2, MMA_N)) rowcol layout and regroups it internally via the new convert_layout_rowcol_dropout helper, so the forward and backward kernels no longer build a separate Aregs-layout view just for dropout.
@@ -886,9 +886,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
     static_assert(MMA_N_SdP % 2 == 0);
     int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-    Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
     flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-        scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
+        scores, params.p_dropout_in_uint8_t, seed, offset,
         block_row_idx, block_col_idx, AtomLayoutMS
     );
 }
@@ -1446,9 +1445,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
     static_assert(MMA_N_SdP % 2 == 0);
     int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
-    Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMmaSdP>(scores.layout()));
     flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-        scores_dropped, params.p_dropout_in_uint8_t, seed, offset,
+        scores, params.p_dropout_in_uint8_t, seed, offset,
         block_row_idx, block_col_idx, AtomLayoutMS
     );
 }
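For orientation, dropout is applied in 16 x 32 blocks, so a key block of kBlockN columns spans kBlockN / 32 dropout column-blocks and each warp group advances by a further MMA_N_SdP / 2 of them. Below is a minimal host-side sketch of the index arithmetic above; every numeric value is an assumption for illustration, not read from Kernel_traits.

```cpp
// Illustrative only: trait values and indices are assumed, not the kernel's real ones.
#include <cstdio>

int main() {
    constexpr int kBlockN      = 128;  // columns per key block -> kBlockN / 32 = 4 dropout column-blocks
    constexpr int MMA_N_SdP    = 8;    // n-tiles of the S/dP MMA (must be even, per the static_assert)
    constexpr int AtomLayoutMS = 2;    // warps along the M dimension

    int n_block = 3, warp_id = 5;
    int block_col_idx = n_block * (kBlockN / 32)                      // 3 * 4 = 12
                      + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);   // 2 * 4 = 8
    std::printf("block_col_idx = %d\n", block_col_idx);               // prints 20
    return 0;
}
```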
@@ -399,27 +399,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Convert scores from fp32 to fp16/bf16
     Tensor rP = flash::convert_type<Element>(scores);
-    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
     int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
     int block_col_idx = n_block * (kBlockN / 32);
     if (Return_softmax) {
         Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-        Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
+        Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
         flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-            tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
+            acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
             block_row_idx, block_col_idx, kNWarps
         );
         cute::copy(acc_s_f16, tSgS);
         tSgS.data() = tSgS.data() + (-kBlockN);
     }
     if (Is_dropout) {
-        flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
+        flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
                              block_row_idx, block_col_idx, kNWarps);
     }
-    // if (cute::thread0()) { print(tOrP); }
 
+    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+    // if (cute::thread0()) { print(tOrP); }
     flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     // if (cute::thread0()) { print(scores); }
 
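The forward kernel uses the same 16 x 32 dropout blocking: each warp of 32 threads owns one 16-row block, so block_row_idx is the query-block offset plus the warp index (tidx / 32). A small sketch with assumed block sizes and thread index, again purely illustrative:

```cpp
// Illustrative only: block sizes and thread index are assumed values.
#include <cstdio>

int main() {
    constexpr int kBlockM = 128;  // rows per query block  -> kBlockM / 16 = 8 dropout row-blocks
    constexpr int kBlockN = 128;  // columns per key block -> kBlockN / 32 = 4 dropout column-blocks

    int m_block = 2, n_block = 3, tidx = 97;                   // tidx / 32 = warp 3
    int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;  // 2 * 8 + 3 = 19
    int block_col_idx = n_block * (kBlockN / 32);              // 3 * 4     = 12
    std::printf("block_row_idx=%d block_col_idx=%d\n", block_row_idx, block_col_idx);
    return 0;
}
```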
@@ -484,26 +484,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
 
     Tensor rP = flash::convert_type<Element>(scores);
-    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
-    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
-    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
     int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
     int block_col_idx = n_block * (kBlockN / 32);
     if (Return_softmax) {
         Tensor acc_s_f16 = flash::convert_type<Element>(acc_s);
-        Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout());
+        Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout());
         flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
-            tOrPdrop, params.p_dropout_in_uint8_t, seed, offset,
+            acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset,
             block_row_idx, block_col_idx, kNWarps
         );
         cute::copy(acc_s_f16, tSgS);
         tSgS.data() = tSgS.data() + (-kBlockN);
     }
     if (Is_dropout) {
-        flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
+        flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset,
                              block_row_idx, block_col_idx, kNWarps);
     }
 
+    // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+    // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
     flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
 }
 
@@ -213,10 +213,12 @@ inline __device__ void apply_mask_causal_w_idx(
 }
 
 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
-inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
+inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
                                      int block_row_start, int block_col_start,
                                      int block_row_stride) {
+    // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
+    Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout()));
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
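The body of the encode_dropout lambda is cut off in this hunk. A host-side sketch of the sign-bit idea it stands for, under the assumption that dropped elements keep their magnitude and are flagged by a negated sign (which is what lets get_dropout_fraction in the tests recover the mask from the returned softmax); the lambda body below is an assumption, not a verbatim copy of softmax.h:

```cpp
// Hedged sketch of encode_dropout_in_sign_bit: dropped values are marked by a
// flipped sign instead of being zeroed, so their magnitudes remain inspectable.
#include <cstdio>

int main() {
    auto encode_dropout = [](bool keep, float val) {
        return keep ? val : -val;   // assumed behavior: sign bit marks "dropped"
    };
    float p[4]    = {0.25f, 0.10f, 0.40f, 0.25f};
    bool  keep[4] = {true, false, true, false};
    for (int i = 0; i < 4; ++i) {
        std::printf("%+.2f ", encode_dropout(keep[i], p[i]));  // +0.25 -0.10 +0.40 -0.25
    }
    std::printf("\n");
    return 0;
}
```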
@@ -211,6 +211,20 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+template<typename Layout>
+inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) {
+    using X = Underscore;
+    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
+    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
+    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<2>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
+    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+                       get<0, 1>(l),
+                       get<1, 1, 1>(l));
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 template <typename To_type, typename Engine, typename Layout>
 inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
     using From_type = typename Engine::value_type;
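To see concretely what the new helper does, the sketch below (plain C++, no CuTe) enumerates which logical (row, col) coordinates of the rowcol view end up in each group of 8 elements of the (8, MMA_M, MMA_N / 2) dropout view: the thread's two adjacent columns, its two rows, and two consecutive n-tiles are fused into one group. The tile counts are toy values, and the coordinates are logical rowcol coordinates, not physical register offsets.

```cpp
// Plain-C++ illustration of the regrouping performed by convert_layout_rowcol_dropout.
#include <cstdio>

int main() {
    constexpr int MMA_M = 1, MMA_N = 4;  // assumed toy tile counts (MMA_N must be even)
    for (int n2 = 0; n2 < MMA_N / 2; ++n2) {      // mode 2 of the dropout view
        for (int m = 0; m < MMA_M; ++m) {         // mode 1 of the dropout view
            std::printf("group (m=%d, n2=%d):", m, n2);
            for (int d = 0; d < 8; ++d) {         // mode 0 has shape (2, 2, 2) = (c2, r2, np)
                int c2 = d % 2;                   // column within the thread's adjacent pair
                int r2 = (d / 2) % 2;             // row within the thread's row pair
                int np = d / 4;                   // which of the two fused n-tiles
                int row = r2 + 2 * m;             // coordinate in nrow = (2, MMA_M)
                int col = c2 + 2 * (np + 2 * n2); // coordinate in ncol = (2, MMA_N)
                std::printf(" (%d,%d)", row, col);
            }
            std::printf("\n");
        }
    }
    return 0;
}
```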
@@ -545,7 +545,7 @@ def get_dropout_fraction(
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("deterministic", [False, True])
# @pytest.mark.parametrize("deterministic", [True])
# @pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("alibi", [False, True])
# @pytest.mark.parametrize("alibi", [False])
@pytest.mark.parametrize("local", [False, True])