Fix race conditions in the Triton bwd for headdim=64
parent 9b0bc97872
commit 731f154de3
@@ -42,7 +42,6 @@ import triton.language as tl
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
# "EVEN_N": lambda args: False,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
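These @triton.heuristics entries are what let the kernel skip bounds masking: each flag is true only when the corresponding size divides the block size evenly (or matches it exactly). A minimal sketch of how the flags evaluate, with illustrative sizes that are not taken from this commit:

    # Hypothetical sizes, for illustration only.
    args = {"seqlen_q": 1023, "seqlen_k": 1024, "headdim": 64,
            "BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_HEADDIM": 64}
    EVEN_M = args["seqlen_q"] % args["BLOCK_M"] == 0          # False: 1023 % 128 == 127
    EVEN_N = args["seqlen_k"] % args["BLOCK_N"] == 0          # True: 1024 % 128 == 0
    EVEN_HEADDIM = args["headdim"] == args["BLOCK_HEADDIM"]   # True: 64 == 64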
@@ -86,8 +85,8 @@ def _fwd_kernel(
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
# [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output!
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
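When EVEN_M & EVEN_N (and EVEN_HEADDIM) do not all hold, the workaround used throughout this file is an explicitly masked load, so out-of-range elements read as 0.0 instead of going through the unmasked path that triggers the bug. A sketch of that pattern applied to q, in the style of the do loads further down; these are not the literal lines from the file:

    # Masked fallback load for a possibly ragged block of q.
    q = tl.load(q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0)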
@@ -238,7 +237,7 @@ def _bwd_kernel_one_col_block(
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
# initialize dv amd dk
# initialize dv and dk
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# k and v stay in SRAM throughout
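Each of these pointer blocks is just base + row_index * row_stride + column_index, broadcast over a (rows, columns) tile. A NumPy sketch of the same address arithmetic, using made-up sizes for illustration:

    import numpy as np
    offs_n = np.arange(4)       # stand-in for the BLOCK_N range of key/value rows
    offs_d = np.arange(8)       # stand-in for the BLOCK_HEADDIM range of columns
    stride_vn = 8               # elements between consecutive rows of V
    # Same broadcasting as V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    v_offsets = offs_n[:, None] * stride_vn + offs_d[None, :]   # shape (4, 8)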
@@ -282,7 +281,8 @@ def _bwd_kernel_one_col_block(
if IS_CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
# There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
if not EVEN_HEADDIM:
# Also wrong for headdim=64.
if not EVEN_M:
tl.debug_barrier()
lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(qk * softmax_scale - lse_i[:, None])
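The backward kernel never stores the softmax matrix; it recomputes it from the saved log-sum-exp, using softmax(x)_j = exp(x_j - logsumexp(x)). A NumPy sketch of that identity for a single row, with illustrative values:

    import numpy as np
    qk = np.array([1.0, 2.0, 3.0])
    softmax_scale = 0.5
    lse = np.log(np.exp(qk * softmax_scale).sum())   # what the forward pass saved
    p = np.exp(qk * softmax_scale - lse)             # same quantity as p above, one row
    assert np.allclose(p, np.exp(qk * softmax_scale) / np.exp(qk * softmax_scale).sum())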
@@ -293,21 +293,26 @@ def _bwd_kernel_one_col_block(
# the output is correct.
if EVEN_M & EVEN_HEADDIM:
do = tl.load(do_ptrs)
else:
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
& (offs_d[None, :] < headdim), other=0.0)
# if EVEN_M:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs)
# else:
# do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
# else:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
# else:
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
# & (offs_d[None, :] < headdim), other=0.0)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
# Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
if not EVEN_M:
tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True)
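For reference, the two dots in this block compute dV += P^T @ dO and dP = dO @ V^T; trans_a/trans_b select which operand tl.dot transposes. A NumPy sketch with toy shapes, illustration only:

    import numpy as np
    BLOCK_M, BLOCK_N, HEADDIM = 4, 3, 8
    p = np.random.rand(BLOCK_M, BLOCK_N)    # recomputed attention probabilities
    do = np.random.rand(BLOCK_M, HEADDIM)   # upstream gradient of the output block
    v = np.random.rand(BLOCK_N, HEADDIM)
    dv = p.T @ do                           # cf. tl.dot(p.to(do.dtype), do, trans_a=True)
    dp = do @ v.T                           # cf. tl.dot(do, v, trans_b=True)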
@@ -366,7 +371,9 @@ def _bwd_kernel_one_col_block(
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
if EVEN_N:
# [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.store(dv_ptrs), there's a race condition
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)
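When the block is not guaranteed to be fully in bounds, the same fix carries over to the write-back: the stores get masked so no thread writes past seqlen_k or headdim. A sketch of the masked variant, in the style of the masked loads above; not the literal lines from the file:

    tl.store(dv_ptrs, dv,
             mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
    tl.store(dk_ptrs, dk,
             mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))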
@@ -536,6 +543,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
assert d <= 128
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded)
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
dq_accum = torch.empty_like(q, dtype=torch.float32)
delta = torch.empty_like(lse)
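Two details in this launcher are easy to miss: lse is expected padded to a multiple of 128 rows, and dq is accumulated in float32 regardless of the input dtype, presumably so the partial contributions (including the ATOMIC_ADD path mentioned in the comments above) add up at full precision. A small sketch of the same setup, with a made-up q shape:

    import math
    import torch
    seqlen_q = 1023
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128       # 1024: the padded lse length
    q = torch.randn(2, seqlen_q, 4, 64).half()               # hypothetical (batch, seqlen, nheads, d)
    dq_accum = torch.empty_like(q, dtype=torch.float32)      # fp32 accumulator for dq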