From ff78ea4123a29e10dccade3d73b7c4869f383b6a Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Fri, 4 Nov 2022 11:20:27 -0700
Subject: [PATCH] Fix race condition in Triton bwd when there's bias

---
 flash_attn/flash_attn_triton.py | 1 +
 tests/test_flash_attn.py        | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index de9d277..af035f3 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -326,6 +326,7 @@ def _bwd_kernel_one_col_block(
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
         if BIAS_TYPE != 'none':
+            tl.debug_barrier()  # Race condition otherwise
             if BIAS_TYPE == 'vector':
                 if EVEN_N:
                     bias = tl.load(b_ptrs).to(tl.float32)
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 94b8661..8434ac4 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -976,7 +976,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
     equal_fn = torch.equal if deterministic_dq else partial(torch.allclose, atol=dq_atol)
     # Run 10000 times and check that the results don't change
     for i in range(10000):
-        output = flash_attn_func(q, k, v, None, causal)
+        output = flash_attn_func(q, k, v, bias, causal)
         output_equal = torch.equal(output, output_0)
         if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
             print(f'Output max diff: {(output - output_0).abs().max().item()}')
@@ -986,6 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
         dk_equal = torch.equal(dk, dk_0)
         dv_equal = torch.equal(dv, dv_0)
         if not (dq_equal and dk_equal and dv_equal):
+            print(f'{i = }')
             print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
             print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
             print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
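
For reference, a minimal standalone sketch of the determinism check that the updated test performs, now exercising the bias path: the Triton flash_attn_func is called repeatedly with the same bias and the output is compared bit-for-bit against the first run; any mismatch points at nondeterminism such as the race fixed above. The shapes, dtype, iteration count, and the 4-D "matrix" bias layout below are illustrative assumptions, not taken from the patch.

# Sketch only: repeated-run determinism check for the Triton kernel with bias.
# Shapes, dtype, and the 4-D bias layout are assumptions for illustration.
import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, seqlen, nheads, d = 1, 128, 4, 64
q, k, v = [torch.randn(batch, seqlen, nheads, d, device='cuda', dtype=torch.float16)
           for _ in range(3)]
# Assumed "matrix" bias: one (seqlen_q, seqlen_k) slice per batch and head
bias = torch.randn(batch, nheads, seqlen, seqlen, device='cuda', dtype=torch.float16)

out_0 = flash_attn_func(q, k, v, bias, False)  # reference output from the first run
for i in range(100):
    out = flash_attn_func(q, k, v, bias, False)
    if not torch.equal(out, out_0):  # any difference across identical runs
        print(f'{i = }, max diff: {(out - out_0).abs().max().item()}')
        break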