diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index b5f905b..df0296b 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -886,9 +886,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); - Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); flash::apply_dropout( - scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + scores, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, AtomLayoutMS ); } @@ -1446,9 +1445,8 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in // Need col to be multiples of 32, since we're doing dropout with block of 16 x 32 static_assert(MMA_N_SdP % 2 == 0); int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2); - Tensor scores_dropped = make_tensor(scores.data(), flash::convert_layout_rowcol_Aregs(scores.layout())); flash::apply_dropout( - scores_dropped, params.p_dropout_in_uint8_t, seed, offset, + scores, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, AtomLayoutMS ); } diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 9cd4049..ca58e66 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -399,27 +399,27 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi // Convert scores from fp32 to fp16/bf16 Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor acc_s_f16 = flash::convert_type(acc_s); - Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout()); + Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout()); flash::apply_dropout( - tOrPdrop, params.p_dropout_in_uint8_t, seed, offset, + acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); cute::copy(acc_s_f16, tSgS); tSgS.data() = tSgS.data() + (-kBlockN); } if (Is_dropout) { - flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps); } - // if (cute::thread0()) { print(tOrP); } + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (cute::thread0()) { print(tOrP); } flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); // if (cute::thread0()) { print(scores); } @@ -484,26 +484,26 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); Tensor rP = flash::convert_type(scores); - // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) - // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
- Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { Tensor acc_s_f16 = flash::convert_type(acc_s); - Tensor tOrPdrop = make_tensor(acc_s_f16.data(), tOrP.layout()); + Tensor acc_s_f16_drop = make_tensor(acc_s_f16.data(), rP.layout()); flash::apply_dropout( - tOrPdrop, params.p_dropout_in_uint8_t, seed, offset, + acc_s_f16_drop, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps ); cute::copy(acc_s_f16, tSgS); tSgS.data() = tSgS.data() + (-kBlockN); } if (Is_dropout) { - flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset, + flash::apply_dropout(rP, params.p_dropout_in_uint8_t, seed, offset, block_row_idx, block_col_idx, kNWarps); } + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. 
+ Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); } diff --git a/csrc/flash_attn/src/softmax.h b/csrc/flash_attn/src/softmax.h index 09a93f1..df449aa 100644 --- a/csrc/flash_attn/src/softmax.h +++ b/csrc/flash_attn/src/softmax.h @@ -213,10 +213,12 @@ inline __device__ void apply_mask_causal_w_idx( } template -inline __device__ void apply_dropout(Tensor &tensor, uint8_t p_dropout_in_uint8_t, +inline __device__ void apply_dropout(Tensor &tensor_, uint8_t p_dropout_in_uint8_t, unsigned long long seed, unsigned long long offset, int block_row_start, int block_col_start, int block_row_stride) { + // tensor_ has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_rowcol_dropout(tensor_.layout())); // tensor has shape (8, MMA_M, MMA_N / 2) using T = typename Engine::value_type; auto encode_dropout = [](bool keep, T val) { diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 31023c5..db02c80 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -211,6 +211,20 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +template +inline __device__ auto convert_layout_rowcol_dropout(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + get<0, 1>(l), + get<1, 1, 1>(l)); +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + template inline __device__ auto convert_type(Tensor const &tensor) { using From_type = typename Engine::value_type; diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 37585ed..f446a4b 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -545,7 +545,7 @@ def get_dropout_fraction( @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) # @pytest.mark.parametrize("dtype", [torch.float16]) @pytest.mark.parametrize("deterministic", [False, True]) -# @pytest.mark.parametrize("deterministic", [True]) +# @pytest.mark.parametrize("deterministic", [False]) @pytest.mark.parametrize("alibi", [False, True]) # @pytest.mark.parametrize("alibi", [False]) @pytest.mark.parametrize("local", [False, True])