Don't enforce bitwise consistency for dq in race condition test

Since we could be parallelizing the backward pass over seqlen_k, dq is no longer bitwise reproducible from run to run.
Tri Dao 2022-11-13 12:21:03 -08:00
parent 7c9953815a
commit 9d3116addf
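
The rationale, briefly: when the backward kernel is parallelized over seqlen_k, the partial contributions to dq from different key blocks are accumulated in whatever order the parallel blocks happen to finish, and floating-point addition is not associative, so two runs agree only up to rounding error rather than bit for bit. A minimal standalone sketch of that effect (illustrative only, not part of the commit):

import torch

torch.manual_seed(0)
# Stand-ins for per-key-block partial contributions to one element of dq.
partials = torch.randn(1000, dtype=torch.float32)

# Accumulate the same values in two different orders, mimicking the
# non-deterministic order in which parallel blocks add into dq.
acc_fwd = torch.zeros(())
for p in partials:
    acc_fwd = acc_fwd + p
acc_rev = torch.zeros(())
for p in partials.flip(0):
    acc_rev = acc_rev + p

# Bitwise equality typically fails because float addition is not associative,
# but the two results agree to within an atol at the rounding-error scale.
print(torch.equal(acc_fwd, acc_rev))
print(torch.allclose(acc_fwd, acc_rev, atol=1e-3))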

@@ -764,6 +764,11 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     g = torch.randn_like(output_unpad_0)
     dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                               (q_unpad, k_unpad, v_unpad), g)
+    # Parallelizing over seqlen_k makes dq non-deterministic
+    deterministic_dq = False
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
     for _ in range(10):
         torch.random.manual_seed(0)
@@ -782,7 +787,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
             dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                                  (q_unpad, k_unpad, v_unpad), g)
-            assert torch.equal(dq_unpad, dq_unpad_0)
+            assert equal_fn(dq_unpad, dq_unpad_0)
             assert torch.equal(dk_unpad, dk_unpad_0)
             assert torch.equal(dv_unpad, dv_unpad_0)
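
Note that the tolerance is not hand-picked: ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0) measures how much a single round-trip of arithmetic perturbs the reference gradient, so its max is roughly one rounding step at dq's magnitude, and that becomes the atol for torch.allclose. A standalone sketch of the same comparison pattern (the helper name and toy tensors below are illustrative, not from the repo):

from functools import partial

import torch


def make_equal_fn(reference, deterministic):
    # Illustrative helper: choose how to compare a re-computed gradient against `reference`.
    # A deterministic kernel can be held to bitwise equality; otherwise derive an absolute
    # tolerance from the reference itself: (x + 0.3 - 0.3) - x is roughly one rounding step
    # at x's magnitude, i.e. the size of error a different accumulation order can introduce.
    if deterministic:
        return torch.equal
    atol = ((reference + 0.3 - 0.3) - reference).abs().max().item()
    return partial(torch.allclose, atol=atol)


# Toy usage: the "rerun" differs from the reference only by fp16 rounding noise.
first = torch.randn(128, 64).half()
rerun = (first + 0.3) - 0.3
equal_fn = make_equal_fn(first, deterministic=False)
print(torch.equal(rerun, first))  # typically False
print(equal_fn(rerun, first))     # True: differences stay within the derived atol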