Merge pull request #1155 from ipiszy/fix
Fix out-of-bound writes for var-seq-len zero-length KVs
commit 28e7f4ddbd
@@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd {
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
         flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
+        if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
     }

 };
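
In the packed variable-length layout, get<0>(layout_O.shape()) is the total token count across the whole batch, so using it as the row bound let this epilogue write zeros and -inf past the end of its own sequence into neighboring sequences' rows; bounding by seqlen_traits_q.actual_seq_len restricts the store to this sequence, which is the out-of-bound write the PR title refers to. A conceptual PyTorch sketch of the predication, not the kernel's code (the helper and tensor shapes are hypothetical stand-ins for the tiled copy):

import torch

def store_zero_block(gO: torch.Tensor, gLSE: torch.Tensor,
                     m_block: int, kBlockM: int, actual_seq_len: int) -> None:
    # Only rows that exist in this sequence are touched; bounding by the padded
    # total length instead of actual_seq_len would spill into the next sequence.
    start = m_block * kBlockM
    valid = max(0, min(kBlockM, actual_seq_len - start))
    gO[start:start + valid] = 0.0
    gLSE[start:start + valid] = float("-inf")
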
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         }
         int n_block_max = collective_mainloop.get_n_block_max(
             mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-        if (Is_causal && n_block_max <= 0) {
+        if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
            scheduler.prefetch_next_work(scheduler_params, work_tile_info);
            scheduler.broadcast_next_work(work_tile_info);
            continue;
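
The producer/scheduler side must also skip tiles whose key range is empty. With fixed lengths only a causal kernel can end up with n_block_max <= 0, but with variable lengths any individual sequence may have seqlen_k == 0. An approximate sketch of the block-count logic, not the kernel's exact get_n_block_max (kBlockM/kBlockN are the tile sizes):

import math

def n_block_max(seqlen_k: int, kBlockN: int, is_causal: bool,
                m_block: int, kBlockM: int, seqlen_q: int) -> int:
    n_max = math.ceil(seqlen_k / kBlockN)
    if is_causal:
        # Under causal masking, keys past the last query row of this block are never attended.
        last_q = min(seqlen_q, (m_block + 1) * kBlockM)
        n_max = min(n_max, math.ceil((last_q + seqlen_k - seqlen_q) / kBlockN))
    return n_max

# A zero-length KV gives n_block_max == 0 even without causal masking, hence the
# added "|| seqlen_traits_k.kUseVarSeqLen" in the skip condition.
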
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         }
         int n_block_max = collective_mainloop.get_n_block_max(
             mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-        if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
+        if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
            collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
            continue;
        }
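
The consumer side uses the same widened condition, and on early exit it still calls store_zero so the valid query rows of this sequence get well-defined values: zeros in O and -inf in LSE, as the comment says. A one-line restatement of the changed condition in plain Python, just to make the logic explicit:

def should_exit_early(is_causal: bool, use_var_seq_len: bool, n_block_max: int) -> bool:
    # Fixed lengths: only the causal kernel can see an empty key range.
    # Variable lengths: any sequence may have zero keys, so the non-causal
    # kernel needs the same early exit.
    return (is_causal or use_var_seq_len) and n_block_max <= 0
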
@@ -236,8 +236,8 @@ def test_flash_attn_varlen_output(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
     )

-    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
-    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')

     (
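
With zero_lengths=True on the key mask, some batch entries contribute no keys at all. On the kernel side a zero-length sequence shows up as a repeated offset in cu_seqlens_k, which is what drives it into the early-exit path above. A small self-contained sketch (the mask values are made up; the cumsum/pad recipe mirrors how the varlen interface builds cumulative sequence lengths):

import torch
import torch.nn.functional as F

key_padding_mask = torch.tensor([[True, True, True, False],
                                 [False, False, False, False],   # zero-length KV
                                 [True, True, False, False]])
seqlens = key_padding_mask.sum(dim=1, dtype=torch.int32)
cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_k)   # tensor([0, 3, 3, 5], dtype=torch.int32)
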
@@ -312,11 +312,16 @@ def test_flash_attn_varlen_output(
         dk_ref,
         dv_ref,
     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
+    dk_ref.masked_fill_(zero_masking, 0.0)
+    dv_ref.masked_fill_(zero_masking, 0.0)
     (
         dq_pt,
         dk_pt,
         dv_pt,
     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    dk_pt.masked_fill_(zero_masking, 0.0)
+    dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
@@ -5,16 +5,23 @@ from einops import rearrange, repeat
 from flash_attn.bert_padding import pad_input, unpad_input


-def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
     assert mode in ["full", "random", "third"]
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
         lengths = torch.randint(
-            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
         )
     elif mode == "third":
         lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+
+    if zero_lengths:
+        # Generate zero-lengths every 5 batches and the last batch.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
     padding_mask = (
         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     )
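
Usage sketch of the updated helper (assuming it is imported from these test utilities): with zero_lengths=True, every fifth batch entry and the last entry are forced to length zero, so the varlen tests deterministically exercise the zero-length-KV path.

import torch

mask = generate_random_padding_mask(max_seqlen=128, batch_size=6,
                                    device="cpu", mode="random", zero_lengths=True)
print(mask.sum(dim=1))   # entries 0 and 5 are 0; the rest fall in [108, 128]
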
@@ -251,4 +258,5 @@ def attention_ref(
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
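
The added line zeroes the reference output for sequences whose keys are entirely padded out, matching the kernel's early-exit convention of writing zeros there. Without it the reference would produce NaN for those rows, since softmax over a fully masked score row is NaN:

import torch

scores = torch.full((1, 4), float("-inf"))   # every key masked out (zero-length KV)
print(torch.softmax(scores, dim=-1))         # tensor([[nan, nan, nan, nan]])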