From 470010f59b00a53f371e602cfcfd9bb2919f140d Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Thu, 3 Nov 2022 15:26:53 -0700
Subject: [PATCH] Fix race condition for Triton bwd for headdim 48 and 96

---
 flash_attn/flash_attn_triton.py | 41 +++++++++++++++------------------
 1 file changed, 19 insertions(+), 22 deletions(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 2fcb721..3f3507f 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -4,21 +4,26 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 
 Changes:
 - Implement both causal and non-causal attention.
-- Implement cross-attention (not just self-attention).
+- Implement both self-attention and cross-attention.
 - Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
-- [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass
-and backward pass. For the backward pass, head dims that are not 64, 128 will require
-more testing since there seems to be some race conditions due to the Triton compiler.
+- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward.
 - Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
 small batch size * nheads.
 
+Caution:
+- If you plan to use headdim other than 64 and 128, you should test for race conditions
+(due to the Triton compiler), as done in the test "test_flash_attn_triton_race_condition"
+in tests/test_flash_attn.py. I've tested and fixed many race conditions
+for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
+that there are none left for other head dimensions.
+
 Differences between this Triton version and the CUDA version:
 - Triton version doesn't support dropout.
 - Triton forward is generally faster than CUDA forward.
-- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly
-slower when headdim=128 and batch * nheads is large.
+- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
+It is slightly slower when headdim=128 and batch * nheads is large.
 - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
 
@@ -276,6 +281,7 @@ def _bwd_kernel_one_col_block(
                                     & (offs_d[None, :] < headdim), other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
+        # Trying to combine the two masks seems to make the result wrong
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
@@ -313,7 +319,7 @@ def _bwd_kernel_one_col_block(
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
         # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
-        if not EVEN_M:
+        if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # There's a race condition for headdim=48
@@ -329,16 +335,10 @@ def _bwd_kernel_one_col_block(
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            if EVEN_M:
-                if EVEN_HEADDIM:
-                    dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, eviction_policy="evict_last")
-                else:
-                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
-                                 eviction_policy="evict_last")
-                    dq += tl.dot(ds, k)
-                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim, eviction_policy="evict_last")
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
             else:
                 if EVEN_HEADDIM:
                     dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
@@ -356,11 +356,8 @@ def _bwd_kernel_one_col_block(
                              eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            if EVEN_M:
-                if EVEN_HEADDIM:
-                    tl.atomic_add(dq_ptrs, dq)
-                else:
-                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
+            if EVEN_M & EVEN_HEADDIM:  # Race condition if we just do EVEN_M
+                tl.atomic_add(dq_ptrs, dq)
             else:
                 if EVEN_HEADDIM:
                     tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
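
Note on testing: the race-condition check referenced in the docstring above amounts to running the same forward and backward repeatedly on identical inputs and requiring bitwise-identical outputs and gradients, since a race in the Triton bwd kernel shows up as run-to-run nondeterminism. The sketch below only illustrates that idea; it is not the actual test in tests/test_flash_attn.py, and it assumes a Triton interface flash_attn.flash_attn_triton.flash_attn_func(q, k, v, bias=None, causal=False) taking (batch, seqlen, nheads, headdim) tensors, so adjust the names and signature to the real test if they differ.

    # Hypothetical standalone race-condition check (the real one lives in
    # tests/test_flash_attn.py::test_flash_attn_triton_race_condition).
    import torch
    from flash_attn.flash_attn_triton import flash_attn_func  # assumed import path

    def check_race_condition(headdim, seqlen_q=1023, seqlen_k=1024, batch=1, nheads=4,
                             causal=False, repeats=30, device="cuda", dtype=torch.float16):
        torch.manual_seed(0)
        # q: (batch, seqlen_q, nheads, headdim); k, v: (batch, seqlen_k, nheads, headdim)
        q = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
        k = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
        v = torch.randn(batch, seqlen_k, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
        g = torch.randn(batch, seqlen_q, nheads, headdim, device=device, dtype=dtype)
        out0 = flash_attn_func(q, k, v, causal=causal)
        dq0, dk0, dv0 = torch.autograd.grad(out0, (q, k, v), g)
        for _ in range(repeats):
            out = flash_attn_func(q, k, v, causal=causal)
            dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            # Identical inputs must give bitwise-identical results; any drift
            # across repeats indicates a race in the Triton kernels.
            assert torch.equal(out, out0)
            assert torch.equal(dq, dq0) and torch.equal(dk, dk0) and torch.equal(dv, dv0)

    for hd in (40, 48, 64, 80, 88, 96, 128):
        check_race_condition(headdim=hd)

Head dimensions 48 and 96, which this patch targets, are worth including in any such sweep, along with seqlen pairs like (1023, 1024) and (108, 256) that the in-kernel comments call out.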