From 6df7e0a02edcee851744168079377a039f6d728d Mon Sep 17 00:00:00 2001
From: muoshuosha <54895915+muoshuosha@users.noreply.github.com>
Date: Thu, 4 Jul 2024 02:07:57 +0800
Subject: [PATCH] Fix the varlen deterministic test (#1023)

Co-authored-by: moshuosha
---
 tests/test_flash_attn.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index ba0e255..9d9b42f 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -2459,9 +2459,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
     g = torch.randn_like(out)
     if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
-        dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
+        dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
         for _ in range(50):
             dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
-            assert torch.equal(dv, dv)
-            assert torch.equal(dk, dk)
-            assert torch.equal(dq, dq)
+            assert torch.equal(dv, dv0)
+            assert torch.equal(dk, dk0)
+            assert torch.equal(dq, dq0)
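
Note on the fix: the removed assertions compared each gradient tensor to itself (torch.equal(dv, dv)), which is always true, so the test could never catch non-deterministic backward results. The patch saves the gradients from one reference backward pass and checks every repeated pass against that reference. The snippet below is a minimal standalone sketch of that pattern; it uses torch.nn.functional.scaled_dot_product_attention as a stand-in for the flash-attn varlen kernel, so the tensor shapes and the attention call are illustrative assumptions, not the repository's actual test setup.

    # Sketch of the determinism-check pattern (stand-in op, not the flash-attn varlen kernel).
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    # Hypothetical shapes: (batch, heads, seqlen, head_dim); the real test uses unpadded varlen inputs.
    q = torch.randn(1, 4, 128, 64, requires_grad=True)
    k = torch.randn(1, 4, 128, 64, requires_grad=True)
    v = torch.randn(1, 4, 128, 64, requires_grad=True)

    out = F.scaled_dot_product_attention(q, k, v)  # stand-in for the deterministic attention forward
    g = torch.randn_like(out)

    # Reference gradients from the first backward pass.
    dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)

    # Every repeated backward pass must reproduce the reference bit-for-bit.
    for _ in range(50):
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
        assert torch.equal(dv, dv0)
        assert torch.equal(dk, dk0)
        assert torch.equal(dq, dq0)

Comparing against a stored reference (dq0, dk0, dv0) rather than against the fresh result itself is what makes the assertions meaningful: any run-to-run variation in the backward kernel now fails the test.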