From 4f81aff46e65277d6df30843e98327d4f2571b5f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 31 Oct 2022 01:25:02 -0700
Subject: [PATCH] Add debug_barrier for all headdims in Triton bwd

---
 flash_attn/flash_attn_triton.py | 4 ++--
 tests/test_flash_attn.py        | 6 ++++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 8bda811..26bf419 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -300,8 +300,8 @@ def _bwd_kernel_one_col_block(
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
-        if not EVEN_HEADDIM:
-            tl.debug_barrier()
+        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+        tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index dd54f8c..f64f190 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -896,11 +896,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
     print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
+    print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
     print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
     print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
     print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
-    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().mean().item()}')
-    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
+    print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')

     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
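
As a sanity check on the race condition this patch works around, a minimal gradient comparison in the spirit of test_flash_attn_triton can be run at the shapes named in the kernel comment (headdim 48/96/128, seqlen_q/seqlen_k = 108/256). This is a sketch, not the repository's test: it assumes flash_attn_func from flash_attn.flash_attn_triton accepts fp16 q, k, v of shape (batch, seqlen, nheads, headdim) on a CUDA device, and it uses a plain PyTorch softmax attention as the reference rather than the test suite's attention_ref helper.

    import torch
    from flash_attn.flash_attn_triton import flash_attn_func

    def attention_pytorch(q, k, v):
        # Plain fp32 softmax attention used only as a reference here.
        q, k, v = [x.float() for x in (q, k, v)]
        scores = torch.einsum('bshd,bthd->bhst', q, k) / (q.shape[-1] ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        return torch.einsum('bhst,bthd->bshd', attn, v)

    torch.manual_seed(0)
    batch, nheads, headdim = 2, 4, 96        # 48/96/128 are the headdims mentioned in the patch
    seqlen_q, seqlen_k = 108, 256            # the (108, 256) case called out in the comment
    q, k, v = [torch.randn(batch, s, nheads, headdim, device='cuda',
                           dtype=torch.float16, requires_grad=True)
               for s in (seqlen_q, seqlen_k, seqlen_k)]

    out = flash_attn_func(q, k, v)           # Triton forward (default softmax scale)
    g = torch.randn_like(out)
    dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)

    out_ref = attention_pytorch(q, k, v)
    dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g.float())

    for name, a, b in [('dQ', dq, dq_ref), ('dK', dk, dk_ref), ('dV', dv, dv_ref)]:
        print(f'{name} max diff: {(a.float() - b).abs().max().item()}')

Before this patch, dQ/dK at these shapes could come out visibly wrong on the affected headdims; with the unconditional tl.debug_barrier() the Triton gradients should track the fp32 reference to within roughly the fp16 rounding error of the PyTorch implementation.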