From 9d3116addf00e585eaf3f249a6543f5970365288 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 13 Nov 2022 12:21:03 -0800
Subject: [PATCH] Don't enforce bitwise consistency for dq in race condition
 test

Since we could be parallelizing over seqlen_k
---
 tests/test_flash_attn.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 0fc614b..d27344e 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -764,6 +764,11 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     g = torch.randn_like(output_unpad_0)
     dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                               (q_unpad, k_unpad, v_unpad), g)
+    # Parallelizing over seqlen_k makes dq non-deterministic
+    deterministic_dq = False
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
 
     for _ in range(10):
         torch.random.manual_seed(0)
@@ -782,7 +787,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
             dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                                 (q_unpad, k_unpad, v_unpad), g)
-            assert torch.equal(dq_unpad, dq_unpad_0)
+            assert equal_fn(dq_unpad, dq_unpad_0)
             assert torch.equal(dk_unpad, dk_unpad_0)
             assert torch.equal(dv_unpad, dv_unpad_0)
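
Note on the tolerance computed above: dq_atol uses a general trick. Round-tripping
a low-precision tensor through one add/subtract forces a rounding step, and the
largest resulting deviation estimates the noise floor of a single arithmetic
operation at that precision. A minimal sketch of the idea, assuming a
hypothetical float16 tensor x standing in for dq_unpad_0 (illustration only,
not code from this patch):

    from functools import partial

    import torch

    x = torch.randn(4, 8, dtype=torch.float16)

    # One add/subtract round trip perturbs x by at most this much, so use
    # the largest deviation as an estimate of the fp16 noise floor.
    atol = ((x + 0.3 - 0.3) - x).abs().max().item()

    # Drop-in replacement for torch.equal that tolerates that noise.
    equal_fn = partial(torch.allclose, atol=atol)

    y = x + 0.3 - 0.3      # bitwise-different but numerically equivalent
    assert equal_fn(x, y)  # passes even where torch.equal would not

This matches the intent of the patch: when the backward pass parallelizes the
dq reduction over seqlen_k blocks, the summation order can differ between runs,
so two correct results may disagree in the last bit; comparing with allclose at
this atol accepts exactly that level of float noise while still catching real
race-condition corruption.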