From 731f154de32834c474ae28f0173f2a36f79f4667 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 1 Nov 2022 14:09:22 -0700
Subject: [PATCH] Fix race conditions in the Triton bwd for headdim=64

---
 flash_attn/flash_attn_triton.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 9ec45c7..54ccf7c 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -42,7 +42,6 @@ import triton.language as tl
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
-        # "EVEN_N": lambda args: False,
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
@@ -86,8 +85,8 @@ def _fwd_kernel(
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
     acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
     # load q: it will stay in SRAM throughout
-    # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
-    # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
+    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
+    # tl.load(q_ptrs), we get the wrong output!
     if EVEN_M & EVEN_N:
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs)
@@ -238,7 +237,7 @@ def _bwd_kernel_one_col_block(
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
-    # initialize dv amd dk
+    # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     # k and v stay in SRAM throughout
@@ -282,7 +281,8 @@ def _bwd_kernel_one_col_block(
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
-        if not EVEN_HEADDIM:
+        # Also wrong for headdim=64.
+        if not EVEN_M:
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
@@ -293,21 +293,26 @@ def _bwd_kernel_one_col_block(
         # the output is correct.
         if EVEN_M & EVEN_HEADDIM:
             do = tl.load(do_ptrs)
+        else:
+            # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
+            do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+                                       & (offs_d[None, :] < headdim), other=0.0)
         # if EVEN_M:
         #     if EVEN_HEADDIM:
         #         do = tl.load(do_ptrs)
         #     else:
         #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
-        else:
-            if EVEN_HEADDIM:
-                do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
-            else:
-                do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
-                                           & (offs_d[None, :] < headdim), other=0.0)
+        # else:
+        #     if EVEN_HEADDIM:
+        #         do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
+        #     else:
+        #         do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
+        #                                    & (offs_d[None, :] < headdim), other=0.0)
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
+        # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
         if not EVEN_M:
             tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
@@ -366,7 +371,9 @@ def _bwd_kernel_one_col_block(
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-    if EVEN_N:
+    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
+    # if we just call tl.store(dv_ptrs), there's a race condition
+    if EVEN_N & EVEN_M:
         if EVEN_HEADDIM:
             tl.store(dv_ptrs, dv)
             tl.store(dk_ptrs, dk)
@@ -536,6 +543,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
     assert d <= 128
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
+    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
     dq_accum = torch.empty_like(q, dtype=torch.float32)
     delta = torch.empty_like(lse)
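
Notes (not part of the diff): the patch leans on two workarounds, (1) a tl.debug_barrier()
between the qk computation and the dependent lse/do loads, forcing a block-wide sync, and
(2) loading and storing with the mask applied over both the seqlen and headdim dimensions
even when only one of them is ragged, since masking a single dimension is what produced
wrong results. Below is a minimal, hypothetical sketch of that two-dimensional mask
pattern in isolation, as a plain tile-copy kernel; the names (_copy_kernel, copy) and
block sizes are illustrative and not from flash_attn_triton.py.

import torch
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(X, Y, seqlen, headdim, stride_xm, stride_ym,
                 BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr):
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # Mask BOTH dimensions, mirroring the patch's
    # (offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim):
    # rows past seqlen and columns past headdim are both suppressed.
    mask = (offs_m[:, None] < seqlen) & (offs_d[None, :] < headdim)
    x = tl.load(X + offs_m[:, None] * stride_xm + offs_d[None, :],
                mask=mask, other=0.0)
    tl.store(Y + offs_m[:, None] * stride_ym + offs_d[None, :], x, mask=mask)

def copy(x):
    seqlen, headdim = x.shape  # e.g. (1023, 64): ragged rows, even headdim
    y = torch.empty_like(x)
    BLOCK_M = 128
    BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
    grid = (triton.cdiv(seqlen, BLOCK_M),)
    _copy_kernel[grid](x, y, seqlen, headdim, x.stride(0), y.stride(0),
                       BLOCK_M=BLOCK_M, BLOCK_HEADDIM=BLOCK_HEADDIM)
    return y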
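
The last hunk also gives _flash_attn_backward a default scale when the caller passes
softmax_scale=None, presumably matching the forward path's 1/sqrt(headdim) scaling so
forward and backward agree. A small illustration of the `or` idiom, with a hypothetical
head dimension:

import math

d = 64  # hypothetical head dimension
softmax_scale = None
# `x or y` evaluates to y when x is falsy, so None (and an explicit 0.0,
# which is not a meaningful scale anyway) falls back to 1/sqrt(d).
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
assert softmax_scale == 0.125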