add a unittest
This commit is contained in:
parent
a3a257c71d
commit
53537da422
@ -236,8 +236,8 @@ def test_flash_attn_varlen_output(
|
||||
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
|
||||
)
|
||||
|
||||
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
|
||||
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
|
||||
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
|
||||
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
|
||||
# key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
|
||||
|
||||
(
|
||||
@ -312,11 +312,16 @@ def test_flash_attn_varlen_output(
|
||||
dk_ref,
|
||||
dv_ref,
|
||||
) = torch.autograd.grad(out_ref, (q, k, v), g)
|
||||
zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
|
||||
dk_ref.masked_fill_(zero_masking, 0.0)
|
||||
dv_ref.masked_fill_(zero_masking, 0.0)
|
||||
(
|
||||
dq_pt,
|
||||
dk_pt,
|
||||
dv_pt,
|
||||
) = torch.autograd.grad(out_pt, (q, k, v), g)
|
||||
dk_pt.masked_fill_(zero_masking, 0.0)
|
||||
dv_pt.masked_fill_(zero_masking, 0.0)
|
||||
dq = dq_pad_fn(dq_unpad)
|
||||
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
|
||||
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
|
||||
|
||||
@ -5,16 +5,23 @@ from einops import rearrange, repeat
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
|
||||
|
||||
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
|
||||
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
|
||||
assert mode in ["full", "random", "third"]
|
||||
if mode == "full":
|
||||
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
|
||||
elif mode == "random":
|
||||
lengths = torch.randint(
|
||||
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
|
||||
max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
|
||||
)
|
||||
elif mode == "third":
|
||||
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
|
||||
|
||||
if zero_lengths:
|
||||
# Generate zero-lengths every 5 batches and the last batch.
|
||||
for i in range(batch_size):
|
||||
if i % 5 == 0:
|
||||
lengths[i] = 0
|
||||
lengths[-1] = 0
|
||||
padding_mask = (
|
||||
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
|
||||
)
|
||||
@ -251,4 +258,5 @@ def attention_ref(
|
||||
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
|
||||
if query_padding_mask is not None:
|
||||
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
|
||||
output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
|
||||
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user