Fix race condition for Triton bwd for headdim 48 and 96

Tri Dao 2022-11-03 15:26:53 -07:00
parent aacc10fbab
commit 470010f59b

@@ -4,21 +4,26 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention

 Changes:
 - Implement both causal and non-causal attention.
-- Implement cross-attention (not just self-attention).
+- Implement both self-attention and cross-attention.
 - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
-and backward pass. For the backward pass, head dims that are not 64, 128 will require
-more testing since there seems to be some race conditions due to the Triton compiler.
+- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
 - Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
 small batch size * nheads.
+
+Caution:
+- If you plan to use headdim other than 64 and 128, you should test for race conditions
+(due to the Triton compiler), as done in tests/test_flash_attn.py
+"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
+for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
+that there are none left for other head dimensions.

 Differences between this Triton version and the CUDA version:
 - Triton version doesn't support dropout.
 - Triton forward is generally faster than CUDA forward.
-- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly
-slower when headdim=128 and batch * nheads is large.
+- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
+It is slightly slower when headdim=128 and batch * nheads is large.
 - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
@@ -276,6 +281,7 @@ def _bwd_kernel_one_col_block(
                                     & (offs_d[None, :] < headdim), other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
+        # Trying to combine the two masks seems to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
@@ -313,7 +319,7 @@ def _bwd_kernel_one_col_block(
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
         # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
-        if not EVEN_M:
+        if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # There's a race condition for headdim=48
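
The workaround in the hunk above is to insert tl.debug_barrier() whenever EVEN_M & EVEN_HEADDIM does not hold, so the Triton compiler cannot reorder the masked loads and stores around the following tl.dot. The toy kernel below is only a schematic of where such a barrier sits between masked memory operations; it does not reproduce the actual race (which involves the compiler's handling of the masked paths around tl.dot), and the kernel and its names are invented for illustration.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _barrier_sketch(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        mask = offs < n  # the "uneven" path, analogous to not (EVEN_M & EVEN_HEADDIM)
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        tl.store(out_ptr + offs, x, mask=mask)
        # Fence between the masked store and the later reads, mirroring the
        # tl.debug_barrier() placed before tl.dot(do, v, trans_b=True) above.
        tl.debug_barrier()
        y = tl.load(out_ptr + offs, mask=mask, other=0.0)
        tl.store(out_ptr + offs, x + y, mask=mask)

    x = torch.randn(100, device="cuda")
    out = torch.empty_like(x)
    _barrier_sketch[(1,)](x, out, x.numel(), BLOCK=128)
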
@@ -329,16 +335,10 @@ def _bwd_kernel_one_col_block(
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            if EVEN_M:
-                if EVEN_HEADDIM:
-                    dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, eviction_policy="evict_last")
-                else:
-                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
-                                 eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim, eviction_policy="evict_last")
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
             else:
                 if EVEN_HEADDIM:
                     dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
@@ -356,11 +356,8 @@ def _bwd_kernel_one_col_block(
                              eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            if EVEN_M:
-                if EVEN_HEADDIM:
-                    tl.atomic_add(dq_ptrs, dq)
-                else:
-                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                tl.atomic_add(dq_ptrs, dq)
             else:
                 if EVEN_HEADDIM:
                     tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)