add a unittest

Ying Zhang 2024-08-17 13:23:50 -07:00
parent a3a257c71d
commit 53537da422
2 changed files with 17 additions and 4 deletions

Changed file 1 of 2

@@ -236,8 +236,8 @@ def test_flash_attn_varlen_output(
         batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True
     )
-    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random")
-    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random")
+    query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random", zero_lengths=False)
+    key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random", zero_lengths=True)
     # key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode='full')
     (
@@ -312,11 +312,16 @@ def test_flash_attn_varlen_output(
         dk_ref,
         dv_ref,
     ) = torch.autograd.grad(out_ref, (q, k, v), g)
+    zero_masking = rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1")
+    dk_ref.masked_fill_(zero_masking, 0.0)
+    dv_ref.masked_fill_(zero_masking, 0.0)
     (
         dq_pt,
         dk_pt,
         dv_pt,
     ) = torch.autograd.grad(out_pt, (q, k, v), g)
+    dk_pt.masked_fill_(zero_masking, 0.0)
+    dv_pt.masked_fill_(zero_masking, 0.0)
     dq = dq_pad_fn(dq_unpad)
     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")

Changed file 2 of 2

@@ -5,16 +5,23 @@ from einops import rearrange, repeat
 from flash_attn.bert_padding import pad_input, unpad_input
-def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"):
+def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random", zero_lengths=False):
     assert mode in ["full", "random", "third"]
     if mode == "full":
         lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32)
     elif mode == "random":
         lengths = torch.randint(
-            max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
+            max(0 if zero_lengths else 1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device
         )
     elif mode == "third":
         lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device)
+    if zero_lengths:
+        # Set the length to zero for every fifth batch element and for the last one.
+        for i in range(batch_size):
+            if i % 5 == 0:
+                lengths[i] = 0
+        lengths[-1] = 0
     padding_mask = (
         repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths
     )
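As a sanity check on the new zero_lengths branch: every fifth batch element (i % 5 == 0) and the last element get length 0, so their mask rows are all False. The sketch below re-derives the mask the same way the helper does, with toy sizes on CPU (sizes are illustrative, not from the test suite):

import torch
from einops import repeat

batch_size, max_seqlen = 7, 16
lengths = torch.randint(1, max_seqlen + 1, (batch_size, 1))
# Mirror the zero_lengths branch above: zero out every fifth row and the last row.
for i in range(batch_size):
    if i % 5 == 0:
        lengths[i] = 0
lengths[-1] = 0
mask = repeat(torch.arange(max_seqlen), "s -> b s", b=batch_size) < lengths
print(torch.any(mask, dim=1))
# tensor([False,  True,  True,  True,  True, False, False])  <- rows 0, 5, 6 are empty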
@@ -251,4 +258,5 @@ def attention_ref(
     output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
     if query_padding_mask is not None:
         output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+    output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
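The one-line addition to attention_ref defines the reference output for empty key sequences: if a batch element has no valid keys at all, its entire output row is set to zero rather than left as the result of a softmax over nothing. Note the new line runs unconditionally, so this path appears to assume key_padding_mask is not None in the tests that exercise it. A toy check of the same expression (shapes assumed for illustration):

import torch
from einops import rearrange

batch, seqlen_q, seqlen_k, nheads, d = 3, 4, 5, 2, 8
output = torch.randn(batch, seqlen_q, nheads, d)
key_padding_mask = torch.ones(batch, seqlen_k, dtype=torch.bool)
key_padding_mask[1] = False  # batch element 1 has a zero-length key sequence
# Same masking expression as the added line above.
output.masked_fill_(rearrange(torch.logical_not(torch.any(key_padding_mask, 1)), "b -> b 1 1 1"), 0.0)
assert (output[1] == 0).all()   # empty-key row is fully zeroed
assert (output[0] != 0).any()   # other rows are untouched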