Add debug_barrier for all headdims in Triton bwd

Tri Dao 2022-10-31 01:25:02 -07:00
parent bedcbd6a71
commit 4f81aff46e
2 changed files with 6 additions and 4 deletions

@@ -300,8 +300,8 @@ def _bwd_kernel_one_col_block(
     dv += tl.dot(p.to(do.dtype), do, trans_a=True)
     # compute dp = dot(v, do)
     # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
-    if not EVEN_HEADDIM:
-        tl.debug_barrier()
+    # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+    tl.debug_barrier()
     dp = tl.dot(do, v, trans_b=True)
     # compute ds = p * (dp - delta[:, None])
     # Putting the subtraction after the dp matmul (instead of before) is slightly faster
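
For reference, tl.debug_barrier() synchronizes all threads of a Triton program, which is why making it unconditional works around the race between the two dot products. Below is a minimal, self-contained sketch of the same pattern, a barrier placed between two tl.dot calls; it is not the FlashAttention kernel, and the kernel name, pointer layout, and BLOCK size are illustrative assumptions.

import torch
import triton
import triton.language as tl


@triton.jit
def _two_dots_kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: tl.constexpr):
    # Illustrative kernel only: load BLOCK x BLOCK tiles of three contiguous,
    # row-major inputs and chain two matmuls with a barrier in between.
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    c = tl.load(c_ptr + offs[:, None] * BLOCK + offs[None, :])
    acc = tl.dot(a, b)
    # Make every thread finish the first matmul before the second one starts,
    # mirroring the barrier added in _bwd_kernel_one_col_block above.
    tl.debug_barrier()
    out = tl.dot(acc.to(c.dtype), c)
    tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], out.to(c.dtype))


def two_dots(a, b, c):
    # a, b, c: square fp16 CUDA tensors of the same power-of-2 size (e.g. 16x16).
    out = torch.empty_like(a)
    _two_dots_kernel[(1,)](a, b, c, out, BLOCK=a.shape[0])
    return out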

@@ -896,11 +896,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
     print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
     print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
     print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
     print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
     print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
     print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
-    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().mean().item()}')
-    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
+    print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
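
The assertions that enforce this bound sit below the displayed hunk and are not part of this diff. As a hedged sketch only, assuming the bound is a plain factor-of-two check on the maximum absolute error against the fp32 reference, such a check could look like the helper below; assert_within_twice_pytorch_error is a hypothetical name, not a function from the test.

def assert_within_twice_pytorch_error(x, x_ref, x_pt, name):
    # x: Triton output, x_ref: fp32 reference, x_pt: Pytorch output in the same
    # low-precision dtype as x. Assumed tolerance: max abs error of the Triton
    # result is at most twice that of the Pytorch implementation.
    err = (x - x_ref).abs().max().item()
    err_pt = (x_pt - x_ref).abs().max().item()
    assert err <= 2 * err_pt, f'{name}: {err} > 2 * {err_pt}'

# Example usage inside the test (dq, dq_ref, dq_pt are defined above):
# assert_within_twice_pytorch_error(dq, dq_ref, dq_pt, 'dQ')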