Fix the varlen deterministic test (#1023)
Co-authored-by: moshuosha <moshuosha@qq.com>
parent 9486635c92
commit 6df7e0a02e
@@ -2459,9 +2459,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     g = torch.randn_like(out)
     if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
         dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv)
-            assert torch.equal(dk, dk)
-            assert torch.equal(dq, dq)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
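For context, the old asserts compared each gradient tensor to itself, so the determinism check was effectively vacuous; the fix compares the gradients from repeated backward passes against the first-run reference gradients dq0/dk0/dv0. Below is a minimal sketch of that check pattern, using a toy computation in place of the flash-attn kernels (the tensor names x and w and the matmul are illustrative assumptions, not part of the real test):

import torch

# Sketch of the determinism-check pattern from the fixed test, on a toy graph.
torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)
out = (x @ w).relu()
g = torch.randn_like(out)

# Reference gradients from the first backward pass.
dx0, dw0 = torch.autograd.grad(out, (x, w), g, retain_graph=True)
for _ in range(50):
    dx, dw = torch.autograd.grad(out, (x, w), g, retain_graph=True)
    # Comparing a tensor to itself (as the old asserts did) cannot detect
    # nondeterminism; the meaningful check is bitwise equality with the
    # first-run gradients.
    assert torch.equal(dx, dx0)
    assert torch.equal(dw, dw0)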