Fix race condition in Triton bwd when there's bias

This commit is contained in:
Tri Dao 2022-11-04 11:20:27 -07:00
parent 86862cfd7b
commit ff78ea4123
2 changed files with 3 additions and 1 deletions

View File

@ -326,6 +326,7 @@ def _bwd_kernel_one_col_block(
if IS_CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
if BIAS_TYPE != 'none':
tl.debug_barrier() # Race condition otherwise
if BIAS_TYPE == 'vector':
if EVEN_N:
bias = tl.load(b_ptrs).to(tl.float32)

View File

@ -976,7 +976,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
# Run 10000 times and check that the results don't change
for i in range(10000):
output = flash_attn_func(q, k, v, None, causal)
output = flash_attn_func(q, k, v, bias, causal)
output_equal = torch.equal(output, output_0)
if not output_equal: # Printing / computing diff sometimes makes the race condition disappear
print(f'Output max diff: {(output - output_0).abs().max().item()}')
@ -986,6 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
dk_equal = torch.equal(dk, dk_0)
dv_equal = torch.equal(dv, dv_0)
if not (dq_equal and dk_equal and dv_equal):
print(f'{i = }')
print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
print(f'dV max diff: {(dv - dv_0).abs().max().item()}')