Fix race conditions in the Triton bwd for headdim=64

This commit is contained in:
Tri Dao 2022-11-01 14:09:22 -07:00
parent 9b0bc97872
commit 731f154de3

View File

@ -42,7 +42,6 @@ import triton.language as tl
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
# "EVEN_N": lambda args: False,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@ -86,8 +85,8 @@ def _fwd_kernel(
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
@ -238,7 +237,7 @@ def _bwd_kernel_one_col_block(
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
# initialize dv amd dk
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# k and v stay in SRAM throughout
@ -282,7 +281,8 @@ def _bwd_kernel_one_col_block(
if IS_CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
if not EVEN_HEADDIM:
# Also wrong for headdim=64.
if not EVEN_M:
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(qk * softmax_scale - lse_i[:, None])
@ -293,21 +293,26 @@ def _bwd_kernel_one_col_block(
# the output is correct.
if EVEN_M & EVEN_HEADDIM:
do = tl.load(do_ptrs)
else:
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0)
# if EVEN_M:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs)
# else:
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
else:
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0)
# else:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
# else:
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
# & (offs_d[None, :] < headdim), other=0.0)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
if not EVEN_M:
tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True)
@ -366,7 +371,9 @@ def _bwd_kernel_one_col_block(
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
if EVEN_N:
# [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.store(dv_ptrs), there's a race condition
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
@ -536,6 +543,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
assert d <= 128
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded)
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
dq_accum = torch.empty_like(q, dtype=torch.float32)
delta = torch.empty_like(lse)