Don't enforce bitwise consistency for dq in race condition test
Since we could be parallelizing over seqlen_k, dq is no longer guaranteed to be bit-for-bit identical across runs.
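Background, as a hedged aside: if the backward kernel splits the reduction that produces dq across blocks of seqlen_k, the order in which floating-point partial sums are combined can vary from run to run, and floating-point addition is not associative, so the result need not be bitwise identical even when it is numerically equivalent. A tiny self-contained illustration of that non-associativity (plain Python floats, not the kernel itself):

# Floating-point addition is not associative: a different reduction order
# can change the low-order bits of the result.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)                  # 0.6000000000000001
print(a + (b + c))                  # 0.6
print((a + b) + c == a + (b + c))   # False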
@@ -764,6 +764,11 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
     g = torch.randn_like(output_unpad_0)
     dq_unpad_0, dk_unpad_0, dv_unpad_0, = torch.autograd.grad(output_unpad_0,
                                                               (q_unpad, k_unpad, v_unpad), g)
+    # Parallelizing over seqlen_k makes dq non-deterministic
+    deterministic_dq = False
+    # Numerical error if we just do any arithmetic on dq
+    dq_atol = ((dq_unpad_0 + 0.3 - 0.3) - dq_unpad_0).abs().max().item()
+    equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)

     for _ in range(10):
         torch.random.manual_seed(0)
@@ -782,7 +787,7 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
         if is_sm80 or d <= 64:  # Only run backward for d=128 on A100
             dq_unpad, dk_unpad, dv_unpad, = torch.autograd.grad(output_unpad,
                                                                 (q_unpad, k_unpad, v_unpad), g)
-            assert torch.equal(dq_unpad, dq_unpad_0)
+            assert equal_fn(dq_unpad, dq_unpad_0)
             assert torch.equal(dk_unpad, dk_unpad_0)
             assert torch.equal(dv_unpad, dv_unpad_0)

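For reference, a minimal standalone sketch of the comparison strategy this diff adds (the tensor name dq_ref and its shape are hypothetical stand-ins, not from the test): estimate the size of ordinary fp16 rounding noise on the reference gradient by running it through a mathematical no-op, then use that as the absolute tolerance for torch.allclose instead of requiring bitwise equality.

from functools import partial

import torch

# Hypothetical stand-in for dq_unpad_0, the reference dq from the first run;
# the real test obtains it from torch.autograd.grad on the attention output.
dq_ref = torch.randn(512, 64).half()

# Adding and subtracting the same constant is a mathematical no-op, so any
# residual is pure fp16 rounding error at the tensor's own scale; its max
# gives a data-dependent absolute tolerance.
dq_atol = ((dq_ref + 0.3 - 0.3) - dq_ref).abs().max().item()

# Bitwise comparison when the backward kernel is deterministic,
# tolerance-based comparison when it parallelizes over seqlen_k.
deterministic_dq = False
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)

# The repeated runs in the test are then checked with
#     assert equal_fn(dq_unpad, dq_unpad_0)
# while dk and dv keep the strict torch.equal check.

The design choice here is that the tolerance is derived from the data itself rather than hard-coded, so the check stays tight for small gradients and does not spuriously fail for large ones.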