diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index ba0e255..9d9b42f 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -2459,9 +2459,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus g = torch.randn_like(out) if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90): - dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) + dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) for _ in range(50): dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) - assert torch.equal(dv, dv) - assert torch.equal(dk, dk) - assert torch.equal(dq, dq) + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert torch.equal(dq, dq0)