Add debug_barrier for all headdims in Triton bwd
commit 4f81aff46e
parent bedcbd6a71
@@ -300,8 +300,8 @@ def _bwd_kernel_one_col_block(
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
-        if not EVEN_HEADDIM:
-            tl.debug_barrier()
+        # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+        tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
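The hunk drops the `if not EVEN_HEADDIM:` guard so the barrier runs for every head dimension, since the same wrong-gradient symptom was also seen at headdim=128 with certain sequence lengths and ATOMIC_ADD=True. The snippet below is a minimal, self-contained sketch of that pattern only: an unconditional tl.debug_barrier() between two dependent tl.dot() calls. It is not the FlashAttention kernel; the kernel name, shapes, and launch code are illustrative assumptions.

# Minimal sketch of the barrier-between-dependent-dots pattern (assumed names/shapes).
import torch
import triton
import triton.language as tl

@triton.jit
def _two_dots_kernel(a_ptr, b_ptr, c_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    idx = offs[:, None] * BLOCK + offs[None, :]   # row-major BLOCK x BLOCK tile
    a = tl.load(a_ptr + idx)
    b = tl.load(b_ptr + idx)
    c = tl.load(c_ptr + idx)
    ab = tl.dot(a, b)                             # first matmul (cf. the dv accumulation)
    # Unconditional barrier, as in this commit: it is no longer gated on the head
    # dimension, and makes all warps finish the first dot before the dependent one.
    tl.debug_barrier()
    abc = tl.dot(ab.to(c.dtype), c)               # second, dependent matmul (cf. dp)
    tl.store(out_ptr + idx, abc)

def two_dots(a, b, c):
    # Assumes square fp16 inputs whose side is a power of 2 (e.g. 64).
    BLOCK = a.shape[0]
    out = torch.empty((BLOCK, BLOCK), device=a.device, dtype=torch.float32)
    _two_dots_kernel[(1,)](a, b, c, out, BLOCK=BLOCK)
    return out

# Usage sketch:
# x = torch.randn(64, 64, device='cuda', dtype=torch.float16)
# y = two_dots(x, x, x)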
@@ -896,11 +896,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
     print(f'dQ mean diff: {(dq - dq_ref).abs().mean().item()}')
     print(f'dK mean diff: {(dk - dk_ref).abs().mean().item()}')
     print(f'dV mean diff: {(dv - dv_ref).abs().mean().item()}')
     print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
     print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
     print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
-    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().mean().item()}')
-    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}')
+    print(f'dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}')
+    print(f'dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}')
+
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
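The relabeled prints report the PyTorch reference's own max and mean error, which is what the comparison described in the trailing comment relies on. As a hedged sketch of such a check, using the tensor names from the prints above; the assertion form and the bare factor of 2 follow the comment, not the repository's actual test code:

# Illustrative only: require the Triton gradients' error to be within twice the
# error of the pure-PyTorch implementation, measured against the same reference.
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()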