Fix race condition for Triton bwd for headdim 48 and 96
parent aacc10fbab
commit 470010f59b
@@ -4,21 +4,26 @@
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 
 Changes:
 - Implement both causal and non-causal attention.
-- Implement cross-attention (not just self-attention).
+- Implement both self-attention and cross-attention.
 - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
-and backward pass. For the backward pass, head dims that are not 64, 128 will require
-more testing since there seems to be some race conditions due to the Triton compiler.
+- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
 - Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
 small batch size * nheads.
 
+Caution:
+- If you plan to use headdim other than 64 and 128, you should test for race conditions
+(due to the Triton compiler), as done in tests/test_flash_attn.py
+"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
+for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
+that there are none left for other head dimensions.
+
 Differences between this Triton version and the CUDA version:
 - Triton version doesn't support dropout.
 - Triton forward is generally faster than CUDA forward.
-- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly
-slower when headdim=128 and batch * nheads is large.
+- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
+It is slightly slower when headdim=128 and batch * nheads is large.
 - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
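The Caution paragraph above points to tests/test_flash_attn.py "test_flash_attn_triton_race_condition" as the way to vet new head dimensions. Below is a minimal sketch of that kind of check, not the repo's actual test: rerun the same forward and backward many times and require bitwise-identical outputs and gradients, since a compiler race shows up as run-to-run nondeterminism. The import path and the positional (bias, causal) arguments of flash_attn_func are assumptions here.

import torch
from flash_attn.flash_attn_triton import flash_attn_func  # assumed import path

def check_triton_race_condition(batch=2, nheads=4, seqlen_q=1023, seqlen_k=1024,
                                headdim=96, repeats=30, causal=True):
    q, k, v = [torch.randn(batch, s, nheads, headdim, device="cuda", dtype=torch.float16,
                           requires_grad=True)
               for s in (seqlen_q, seqlen_k, seqlen_k)]
    g = torch.randn(batch, seqlen_q, nheads, headdim, device="cuda", dtype=torch.float16)
    out0 = flash_attn_func(q, k, v, None, causal)           # assumed signature: (q, k, v, bias, causal)
    dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
    for _ in range(repeats):
        out = flash_attn_func(q, k, v, None, causal)
        dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
        # Any race condition surfaces as run-to-run nondeterminism, so demand exact equality.
        assert torch.equal(out, out0)
        assert torch.equal(dq, dq0) and torch.equal(dk, dk0) and torch.equal(dv, dv0)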
@@ -276,6 +281,7 @@ def _bwd_kernel_one_col_block(
                                          & (offs_d[None, :] < headdim), other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
+        # Trying to combine the two masks seem to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
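The added comment and the EVEN_N branch in the hunk above concern key positions past seqlen_k in a partially filled block: their scores must be forced to -inf before the softmax so they get zero probability. A small plain-PyTorch illustration (not the kernel code) of that masking:

import torch

seqlen_k, block_n = 100, 128            # the last key block is only partially valid
qk = torch.randn(16, block_n)           # scores of one query block against one key block
offs_n = torch.arange(block_n)
qk = qk.masked_fill(offs_n[None, :] >= seqlen_k, float("-inf"))  # same effect as the tl.where
p = torch.softmax(qk, dim=-1)
assert torch.all(p[:, seqlen_k:] == 0)  # padded keys contribute nothing to the output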
@@ -313,7 +319,7 @@ def _bwd_kernel_one_col_block(
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
         # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
-        if not EVEN_M:
+        if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # There's a race condition for headdim=48
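Why head dimensions 48 and 96 in particular hit this barrier path: a sketch, assuming (as in the launch heuristics in flash_attn_triton.py) that EVEN_HEADDIM means the head dimension exactly fills its padded power-of-two block. For 48 and 96 the block is padded to 64 and 128, so EVEN_HEADDIM is False, the loads are masked, and the kernel takes the tl.debug_barrier() branch above. The helper name and default block sizes below are illustrative.

import triton

def even_flags(seqlen_q, seqlen_k, headdim, BLOCK_M=128, BLOCK_N=128):
    # Assumed to mirror the EVEN_* heuristics used when launching the Triton kernels.
    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
    return {
        "EVEN_M": seqlen_q % BLOCK_M == 0,
        "EVEN_N": seqlen_k % BLOCK_N == 0,
        "EVEN_HEADDIM": headdim == BLOCK_HEADDIM,
    }

print(even_flags(1024, 1024, 48))   # EVEN_HEADDIM False -> masked loads, barrier needed
print(even_flags(1024, 1024, 96))   # EVEN_HEADDIM False
print(even_flags(1024, 1024, 64))   # EVEN_HEADDIM True  -> unmasked fast path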
@@ -329,16 +335,10 @@ def _bwd_kernel_one_col_block(
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            if EVEN_M:
-                if EVEN_HEADDIM:
-                    dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, eviction_policy="evict_last")
-                else:
-                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
-                                 eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim, eviction_policy="evict_last")
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
             else:
                 if EVEN_HEADDIM:
                     dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
@@ -356,11 +356,8 @@ def _bwd_kernel_one_col_block(
                              eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            if EVEN_M:
-                if EVEN_HEADDIM:
-                    tl.atomic_add(dq_ptrs, dq)
-                else:
-                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                tl.atomic_add(dq_ptrs, dq)
             else:
                 if EVEN_HEADDIM:
                     tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
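For the last two hunks, some context on the ATOMIC_ADD branch: when the backward pass is parallelized across seqlen_k, one program per key block writes into the same dq rows concurrently, so each block's contribution must be accumulated with tl.atomic_add rather than a load-add-store. A plain-PyTorch sketch of the accumulation the atomics implement (block size and shapes are arbitrary, and the serial loop stands in for the parallel grid):

import torch

def dq_reference(ds, k, block_n=128):
    # ds: (seqlen_q, seqlen_k) term from the backward pass, k: (seqlen_k, headdim)
    dq = torch.zeros(ds.shape[0], k.shape[1])
    # Each iteration plays the role of one key-block program; in the kernel these run
    # concurrently and accumulate into dq via tl.atomic_add(dq_ptrs, tl.dot(ds, k)).
    for start_n in range(0, k.shape[0], block_n):
        dq += ds[:, start_n:start_n + block_n] @ k[start_n:start_n + block_n]
    return dq

ds, k = torch.randn(256, 512), torch.randn(512, 64)
assert torch.allclose(dq_reference(ds, k), ds @ k, atol=1e-4)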