From 6998e0ecdba7aa06e1dc22357ef9729ac06af5f4 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Wed, 9 Nov 2022 08:17:37 -0800
Subject: [PATCH] Fix out-of-bound memory read

---
 .../src/fmha_dgrad_fp16_kernel_loop.sm80.cu |  2 +-
 .../src/fmha_dgrad_kernel_1xN_loop.h        | 18 ++++++++++++++++++
 csrc/flash_attn/src/fmha_fprop_kernel_1xN.h |  2 ++
 tests/test_flash_attn.py                    |  9 +++++++--
 4 files changed, 28 insertions(+), 3 deletions(-)

diff --git a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
index a01d92a..085cc3f 100644
--- a/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
+++ b/csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
@@ -7,7 +7,7 @@
 #include "fmha_dgrad_kernel_1xN_loop.h"
 
 // Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
-// Parallelizing will have better occupancy, but has some overhead due to having to zero out dq
+// Parallelizing will have better occupancy, but has some overhead due to having to zero out
 // dq_tmp and having to copy dq_tmp to dq.
 int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
                              int blocksize, bool is_causal) {
diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
index 1888ea5..52ce4c5 100644
--- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -271,6 +271,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
 
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
     int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // Otherwise we'd be reading out-of-bound memory before the loop
+    if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) {
+        // Still need to zero out dk and dv before returning
+        static_assert(Smem_tile_dk::NUM_LDS == Smem_tile_dv::NUM_LDS);
+        uint4 dkv_out[Smem_tile_dk::NUM_LDS];
+        #pragma unroll
+        for (int i = 0; i < Smem_tile_dk::NUM_LDS; ++i) { dkv_out[i] = make_uint4(0u, 0u, 0u, 0u); }
+        Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
+                             params.d, binfo, tidx, false);
+        if (!Is_first) { gmem_dk.move(loop_step_idx); }
+        gmem_dk.store(dkv_out);
+        Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
+                             params.d, binfo, tidx, false);
+        if (!Is_first) { gmem_dv.move(loop_step_idx); }
+        gmem_dv.store(dkv_out);
+        return;
+    }
+
     const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
     // Wind gmem tiles to the correct position.
     gmem_q.move(begin);
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index fd4621b..6c54566 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -280,6 +280,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // k * gridDim.z + 1 for integer k.
     const int begin_mod_z = begin % gridDim.z;
     begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
+    // Otherwise we'd be reading out-of-bound memory before the loop
+    if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
     const int steps_og = steps;
     steps -= begin;
     gmem_q.move(begin + blockIdx.z);
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index b85da45..0fc614b 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -12,6 +12,11 @@ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
 from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
 
+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except (ImportError, AttributeError):  # Older version of Triton doesn't have tl.constexpr
+    flash_attn_func = None
+
 
 is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
 is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
@@ -857,9 +862,8 @@ def test_flash_attn_multigpu():
     assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
 
 
-from flash_attn.flash_attn_triton import flash_attn_func
-
 
+@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@@ -930,6 +934,7 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_sha
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
 
 
+@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
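
Note (illustration, not part of the patch): the new early-exit guards exist because, in the causal backward pass, a K/V block whose first query row begin * Cta_tile_p::M lies at or past binfo.actual_seqlen_q would read out-of-bound memory before its loop even starts, yet it must still write zeroed dk/dv for that block. The Python sketch below mirrors that condition; the helper name and the tile sizes M=16, N=128 are assumed example values, not the kernel's actual compile-time constants.

    def hits_early_exit(loop_step_idx, actual_seqlen_q, M=16, N=128, is_causal=True):
        # Mirrors: int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
        begin = (loop_step_idx * N) // M if is_causal else 0
        # Mirrors the added guard: if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) { ...; return; }
        return begin * M >= actual_seqlen_q

    # Example: causal backward with a short query (seqlen_q=48) and a long key (seqlen_k=512).
    # K/V blocks 1-3 start past the last valid query row, so they now zero dk/dv and return
    # instead of reading out of bounds.
    for step in range(512 // 128):
        print(step, hits_early_exit(step, actual_seqlen_q=48))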