Fix out-of-bound memory read

Tri Dao 2022-11-09 08:17:37 -08:00
parent 908a5b2244
commit 6998e0ecdb
4 changed files with 28 additions and 3 deletions

@@ -7,7 +7,7 @@
 #include "fmha_dgrad_kernel_1xN_loop.h"
 // Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
-// Parallelizing will have better occupancy, but has some overhead due to having to zero out dq
+// Parallelizing will have better occupancy, but has some overhead due to having to zero out
 // dq_tmp and having to copy dq_tmp to dq.
 int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
                              int blocksize, bool is_causal) {
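For context on the trade-off this comment describes: splitting the backward pass across seqlen_k launches more thread blocks and therefore fills the GPU better, but every extra split pays for zeroing dq_tmp and for the final dq_tmp -> dq copy. The sketch below only illustrates that trade-off; it is not the actual num_splits_heuristic_bwd. The wave-efficiency measure, the 0.05 overhead constant, and the simplified signature (no is_causal handling) are all assumptions.

#include <cmath>

// Hypothetical host-side heuristic: favor split counts that fill the last wave of CTAs,
// charging an assumed fixed penalty per extra split for the dq_tmp zero/copy traffic.
int pick_num_splits_bwd_sketch(int batch_nheads, int num_SMs, int ctas_per_sm,
                               int seqlen, int blocksize) {
    const int max_ctas = num_SMs * ctas_per_sm;                   // CTAs resident per wave
    const int num_blocks = (seqlen + blocksize - 1) / blocksize;  // finest useful split
    int best_splits = 1;
    float best_score = -1.f;
    for (int num_splits = 1; num_splits <= num_blocks; ++num_splits) {
        const float n_waves = float(batch_nheads * num_splits) / max_ctas;
        const float efficiency = n_waves / std::ceil(n_waves);    // how full the last wave is
        const float overhead = 0.05f * (num_splits - 1);          // assumed zero/copy penalty
        const float score = efficiency - overhead;
        if (score > best_score) { best_score = score; best_splits = num_splits; }
    }
    return best_splits;
}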

@@ -271,6 +271,24 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
     int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
+    // Otherwise we'd be reading out-of-bound memory before the loop
+    if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) {
+        // Still need to zero out dk and dv before returning
+        static_assert(Smem_tile_dk::NUM_LDS == Smem_tile_dv::NUM_LDS);
+        uint4 dkv_out[Smem_tile_dk::NUM_LDS];
+        #pragma unroll
+        for (int i = 0; i < Smem_tile_dk::NUM_LDS; ++i) { dkv_out[i] = make_uint4(0u, 0u, 0u, 0u); }
+        Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
+                             params.d, binfo, tidx, false);
+        if (!Is_first) { gmem_dk.move(loop_step_idx); }
+        gmem_dk.store(dkv_out);
+        Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
+                             params.d, binfo, tidx, false);
+        if (!Is_first) { gmem_dv.move(loop_step_idx); }
+        gmem_dv.store(dkv_out);
+        return;
+    }
     const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
     // Wind gmem tiles to the correct position.
     gmem_q.move(begin);
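To see why the new guard is needed: with causal masking, the dq/dk/dv loop for K/V block loop_step_idx starts at row-block begin = loop_step_idx * N / M, and for a short query sequence (e.g. a short sequence in a varlen batch, or seqlen_q < seqlen_k) that start can already lie past actual_seqlen_q, so gmem_q.move(begin) would read memory outside the sequence. The early return still has to store zeros for dk and dv because no valid query row attends to this K/V block, so its gradient is exactly zero and must be written out rather than left uninitialized. A worked example with assumed sizes (M = N = 128 and a 70-token query, both assumptions for illustration):

#include <cstdio>

int main() {
    const int M = 128, N = 128;       // stand-ins for Cta_tile_p::M and Cta_tile_p::N
    const int actual_seqlen_q = 70;   // assumed short query sequence
    for (int loop_step_idx = 0; loop_step_idx < 4; ++loop_step_idx) {
        const int begin = loop_step_idx * N / M;          // causal start row-block
        const bool skip = begin * M >= actual_seqlen_q;   // the added bounds check
        std::printf("loop_step_idx=%d begin=%d first_row=%d -> %s\n",
                    loop_step_idx, begin, begin * M,
                    skip ? "early return, store zero dk/dv" : "run the loop");
    }
    return 0;
}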

@@ -280,6 +280,8 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
     // k * gridDim.z + 1 for integer k.
     const int begin_mod_z = begin % gridDim.z;
     begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
+    // Otherwise we'd be reading out-of-bound memory before the loop
+    if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
     const int steps_og = steps;
     steps -= begin;
     gmem_q.move(begin + blockIdx.z);
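Here the loop over query row-blocks for one K/V block is split across gridDim.z thread blocks, with block z owning row-blocks congruent to z modulo gridDim.z (the "k * gridDim.z + 1" comment above). Rounding begin to a multiple of gridDim.z preserves that ownership, and the new check returns early when block z's first row-block, begin + blockIdx.z, already starts past the end of the query sequence. A host-side illustration with assumed values (gridDim.z = 4, M = 16, begin = 10 before rounding, a 180-token query; all numbers are made up):

#include <cstdio>

int main() {
    const int grid_z = 4, M = 16, actual_seqlen_q = 180;   // assumed values
    const int begin_orig = 10;                              // causal start before rounding
    for (int z = 0; z < grid_z; ++z) {
        const int begin_mod_z = begin_orig % grid_z;
        // Same rounding as the kernel: keep block z on row-blocks == z (mod grid_z)
        // without ever starting before the original begin.
        const int begin = begin_mod_z <= z ? begin_orig - begin_mod_z
                                           : begin_orig + grid_z - begin_mod_z;
        const bool skip = (begin + z) * M >= actual_seqlen_q;   // the added bounds check
        std::printf("z=%d begin=%d first_row=%d -> %s\n",
                    z, begin, (begin + z) * M, skip ? "early return" : "process");
    }
    return 0;
}

With these numbers, blocks z = 0 and z = 1 round begin up to 12 and would start at rows 192 and 208, past the 180-token sequence, so they hit the early return; blocks z = 2 and z = 3 round down to 8 and start at rows 160 and 176.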

@@ -12,6 +12,11 @@ from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_unpadded
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
 from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+try:
+    from flash_attn.flash_attn_triton import flash_attn_func
+except (ImportError, AttributeError):  # Older version of Triton doesn't have tl.constexpr
+    flash_attn_func = None

 is_sm75 = torch.cuda.get_device_capability('cuda') == (7, 5)
 is_sm80 = torch.cuda.get_device_capability('cuda') == (8, 0)
@@ -857,9 +862,8 @@ def test_flash_attn_multigpu():
     assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
-from flash_attn.flash_attn_triton import flash_attn_func
+@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
@@ -930,6 +934,7 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_sha
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+@pytest.mark.skipif(flash_attn_func is None, reason='Triton is not installed or is too old')
 @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])