Fix race conditions in the Triton bwd for headdim=64

This commit is contained in:
Tri Dao 2022-11-01 14:09:22 -07:00
parent 9b0bc97872
commit 731f154de3

View File

@ -42,7 +42,6 @@ import triton.language as tl
{ {
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
# "EVEN_N": lambda args: False,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
} }
) )
@ -86,8 +85,8 @@ def _fwd_kernel(
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout # load q: it will stay in SRAM throughout
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler? # tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N: if EVEN_M & EVEN_N:
if EVEN_HEADDIM: if EVEN_HEADDIM:
q = tl.load(q_ptrs) q = tl.load(q_ptrs)
@ -238,7 +237,7 @@ def _bwd_kernel_one_col_block(
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
# initialize dv amd dk # initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# k and v stay in SRAM throughout # k and v stay in SRAM throughout
@ -282,7 +281,8 @@ def _bwd_kernel_one_col_block(
if IS_CAUSAL: if IS_CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
if not EVEN_HEADDIM: # Also wrong for headdim=64.
if not EVEN_M:
tl.debug_barrier() tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr) lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(qk * softmax_scale - lse_i[:, None]) p = tl.exp(qk * softmax_scale - lse_i[:, None])
@ -293,21 +293,26 @@ def _bwd_kernel_one_col_block(
# the output is correct. # the output is correct.
if EVEN_M & EVEN_HEADDIM: if EVEN_M & EVEN_HEADDIM:
do = tl.load(do_ptrs) do = tl.load(do_ptrs)
else:
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0)
# if EVEN_M: # if EVEN_M:
# if EVEN_HEADDIM: # if EVEN_HEADDIM:
# do = tl.load(do_ptrs) # do = tl.load(do_ptrs)
# else: # else:
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else: # else:
if EVEN_HEADDIM: # if EVEN_HEADDIM:
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
else: # else:
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0) # & (offs_d[None, :] < headdim), other=0.0)
dv += tl.dot(p.to(do.dtype), do, trans_a=True) dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do) # compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong. # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
if not EVEN_M: if not EVEN_M:
tl.debug_barrier() tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True) dp = tl.dot(do, v, trans_b=True)
@ -366,7 +371,9 @@ def _bwd_kernel_one_col_block(
# write-back # write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
if EVEN_N: # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.store(dv_ptrs), there's a race condition
if EVEN_N & EVEN_M:
if EVEN_HEADDIM: if EVEN_HEADDIM:
tl.store(dv_ptrs, dv) tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk) tl.store(dk_ptrs, dk)
@ -536,6 +543,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
assert d <= 128 assert d <= 128
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded) assert lse.shape == (batch, nheads, seqlen_q_rounded)
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
# dq_accum = torch.zeros_like(q, dtype=torch.float32) # dq_accum = torch.zeros_like(q, dtype=torch.float32)
dq_accum = torch.empty_like(q, dtype=torch.float32) dq_accum = torch.empty_like(q, dtype=torch.float32)
delta = torch.empty_like(lse) delta = torch.empty_like(lse)