Merge pull request #1155 from ipiszy/fix
Fix out-of-bound writes for var-seq-len zero-length KVs
Commit: 28e7f4ddbd
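For context, a hedged sketch of the situation this PR targets: in the variable-sequence-length (varlen) layout, per-sequence key lengths are described by cumulative offsets, and a zero-length KV sequence shows up as a repeated offset. The tensor names and sizes below are illustrative, not taken from this diff.

import torch

# Hypothetical varlen batch of 3 sequences whose key lengths are 5, 0, 7.
# The middle sequence has no keys at all -- the case the kernels below must handle.
seqlens_k = torch.tensor([5, 0, 7], dtype=torch.int32)
cu_seqlens_k = torch.nn.functional.pad(seqlens_k.cumsum(0), (1, 0)).to(torch.int32)
print(cu_seqlens_k)  # tensor([ 0,  5,  5, 12], ...) -- the repeated 5 marks the empty sequence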
@@ -285,10 +285,10 @@ struct CollectiveEpilogueFwd {
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.layout_O.shape()); }
         // Clear_OOB_K must be false since we don't want to write zeros to gmem
         flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.layout_O.shape()) - m_block * kBlockM
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_traits_q.actual_seq_len - m_block * kBlockM
         );
         static_assert(kBlockM <= NumMmaThreads);
-        if (thread_idx < get<0>(epilogue_params.layout_LSE.shape()) - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
+        if (thread_idx < seqlen_traits_q.actual_seq_len - m_block * kBlockM) { gLSE(thread_idx) = -INFINITY; }
     }

 };
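A rough illustration, in plain Python with made-up sizes, of why the epilogue now bounds its O/LSE writes by seqlen_traits_q.actual_seq_len: the bound taken from the O/LSE layout shape reflects the packed buffer, which can be much larger than the number of rows that actually belong to the current sequence, so when that sequence is short or empty, threads past its end still passed the old predicate and wrote outside its rows.

# Made-up tile and sequence sizes, only to show the predicate difference.
kBlockM = 128            # query rows handled per block
m_block = 0              # first (and only) row block of this sequence
layout_rows = 4096       # rows visible through the packed O/LSE layout
actual_seq_len = 37      # rows that actually belong to this sequence

old_bound = layout_rows - m_block * kBlockM       # 4096 -> every thread writes
new_bound = actual_seq_len - m_block * kBlockM    # 37   -> only valid rows are written
writers_old = sum(1 for t in range(kBlockM) if t < old_bound)
writers_new = sum(1 for t in range(kBlockM) if t < new_bound)
print(writers_old, writers_new)  # 128 37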
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         }
         int n_block_max = collective_mainloop.get_n_block_max(
             mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-        if (Is_causal && n_block_max <= 0) {
+        if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) {
             scheduler.prefetch_next_work(scheduler_params, work_tile_info);
             scheduler.broadcast_next_work(work_tile_info);
             continue;
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         }
         int n_block_max = collective_mainloop.get_n_block_max(
             mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
-        if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
+        if ((Is_causal || seqlen_traits_k.kUseVarSeqLen) && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
             collective_epilogue.store_zero(epilogue_params, shared_storage, threadIdx.x - NumCopyThreads, block_coord, seqlen_traits_q);
             continue;
         }
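A minimal Python sketch of the early-exit condition changed in the two hunks above, assuming n_block_max is roughly the number of kBlockN-sized key tiles for the current sequence (the helper names below are hypothetical, not the repo's code): when a varlen batch contains a zero-length KV sequence, that tile count is 0 even without causal masking, so the kernel must skip the mainloop and fall through to the zero/-inf epilogue instead of reading and writing out of bounds.

def ceil_div(a, b):
    return (a + b - 1) // b

def exits_early(is_causal, use_var_seq_len, actual_seqlen_k, kBlockN=128):
    # Rough stand-in for collective_mainloop.get_n_block_max(): key tiles to visit.
    n_block_max = ceil_div(actual_seqlen_k, kBlockN)
    # The old condition was `is_causal and n_block_max <= 0`; a non-causal varlen
    # kernel with zero keys therefore kept going and wrote out of bounds.
    return (is_causal or use_var_seq_len) and n_block_max <= 0

print(exits_early(is_causal=False, use_var_seq_len=True, actual_seqlen_k=0))     # True
print(exits_early(is_causal=False, use_var_seq_len=False, actual_seqlen_k=512))  # False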
@@ -236,8 +236,8 @@ def test_flash_attn_varlen_output(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
     )

-    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
-    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')

     (
@@ -312,11 +312,16 @@ def test_flash_attn_varlen_output(
         dk_ref,
         dv_ref,
     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
+    dk_ref.masked_fill_(zero_masking, 0.0)
+    dv_ref.masked_fill_(zero_masking, 0.0)
     (
         dq_pt,
         dk_pt,
         dv_pt,
     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    dk_pt.masked_fill_(zero_masking, 0.0)
+    dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
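Why the test needs zero_masking for the reference gradients: when a batch element's key sequence is entirely masked out, the reference path ends up softmaxing a row of -inf scores, and the resulting NaNs can propagate into dk_ref/dv_ref even though the forward output for that element is zeroed; the kernel instead defines those gradients as zero, so the test zeroes them on the reference side as well. A tiny standalone check of that softmax behaviour (not code from the repo):

import torch

scores = torch.full((1, 4), float("-inf"))  # every key position masked out
probs = torch.softmax(scores, dim=-1)
print(probs)  # tensor([[nan, nan, nan, nan]]) -- exp(-inf) sums to 0, giving 0/0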
@@ -5,16 +5,23 @@ from einops import rearrange, repeat
 from flash_attn.bert_padding import pad_input, unpad_input


-def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
     assert mode in ["full", "random", "third"]
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
         lengths = torch.randint(
-            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
         )
     elif mode == "third":
         lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)

+    if zero_lengths:
+        # Generate zero-lengths every 5 batches and the last batch.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
     padding_mask = (
         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     )
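For clarity, a self-contained sketch of what the new zero_lengths path produces, re-deriving the mask construction from the diff with illustrative sizes rather than importing the test helper:

import torch
from einops import repeat

max_seqlen, batch_size = 32, 7
lengths = torch.randint(1, max_seqlen + 1, (batch_size, 1))
lengths[::5] = 0   # every 5th batch element gets an empty sequence...
lengths[-1] = 0    # ...and so does the last one, mirroring the loop above
padding_mask = repeat(torch.arange(max_seqlen), "s -> b s", b=batch_size) < lengths
print(padding_mask.sum(dim=-1))  # per-sequence lengths; entries 0, 5 and 6 are 0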
@@ -251,4 +258,5 @@ def attention_ref(
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
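And a quick demonstration of the broadcast used by the new attention_ref line above: any batch element with no valid keys has its whole output zeroed, matching the kernel's early-exit path that writes 0 to gO (shapes and values below are made up):

import torch
from einops import rearrange

# batch=2, seqlen_q=3, nheads=2, d=8; the second element has no valid keys.
key_padding_mask = torch.tensor([[True, True, False, False],
                                 [False, False, False, False]])
output = torch.randn(2, 3, 2, 8)
zero_rows = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
output.masked_fill_(zero_rows, 0.0)
print(output[1].abs().max())  # tensor(0.) -- the empty-key element is fully zeroed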