Fix the varlen deterministic test (#1023)

Co-authored-by: moshuosha <moshuosha@qq.com>
This commit is contained in:
muoshuosha 2024-07-04 02:07:57 +08:00 committed by GitHub
parent 9486635c92
commit 6df7e0a02e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2459,9 +2459,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
g = torch.randn_like(out)
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
assert torch.equal(dv, dv)
assert torch.equal(dk, dk)
assert torch.equal(dq, dq)
assert torch.equal(dv, dv0)
assert torch.equal(dk, dk0)
assert torch.equal(dq, dq0)