Fix race condition in Triton fwd

parent 215930bce3
commit 9b0bc97872
@@ -1,5 +1,5 @@
 """
-Based on the FlashAttention implementation from Phil Tillet.
+We use the FlashAttention implementation from Phil Tillet as a starting point.
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
 
 Changes:
@@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
 small batch size * nheads.
+
+Differences between this Triton version and the CUDA version:
+- Triton version doesn't support dropout.
+- Triton forward is generally faster than CUDA forward.
+- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
+slower in other cases.
+- Triton version does not yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
 
 import math
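A minimal usage sketch of the interface the docstring above describes, assuming the module exports flash_attn_func(q, k, v, bias=None, causal=False, softmax_scale=None) operating on (batch, seqlen, nheads, headdim) half-precision CUDA tensors, as the test added at the end of this commit exercises; the import path and signature are assumptions, not taken from this hunk.

    import torch
    from flash_attn.flash_attn_triton import flash_attn_func  # assumed import path

    batch, seqlen, nheads, d = 2, 1024, 4, 64
    q, k, v = [torch.randn(batch, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                           requires_grad=True)
               for _ in range(3)]
    out = flash_attn_func(q, k, v, None, True)  # bias=None, causal=True; dropout is not supported
    out.sum().backward()                        # Triton backward pass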
@@ -26,7 +33,8 @@ import triton.language as tl
 @triton.autotune(
     configs=[
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
-        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+        # This config has a race condition when EVEN_M == False, disabling it for now.
+        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
     ],
     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
 )
@@ -34,6 +42,7 @@ import triton.language as tl
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+        # "EVEN_N": lambda args: False,
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
@@ -95,7 +104,7 @@ def _fwd_kernel(
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        if EVEN_N:
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
             if EVEN_HEADDIM:
                 k = tl.load(k_ptrs + start_n * stride_kn)
             else:
@@ -129,7 +138,7 @@ def _fwd_kernel(
         acc_o_scale = tl.load(t_ptrs)
         acc_o = acc_o * acc_o_scale[:, None]
         # update acc_o
-        if EVEN_N:
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
             if EVEN_HEADDIM:
                 v = tl.load(v_ptrs + start_n * stride_vn)
             else:
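The EVEN_M / EVEN_N heuristics decide whether a block can take the unmasked fast path in the two hunks above; when either is false, loads have to be masked at the block boundary. A minimal self-contained sketch of that masked-load pattern, using a made-up copy kernel rather than the attention kernel itself, with all names in it assumed:

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n                                  # mask off the ragged tail of the last block
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)  # padded load instead of an unmasked one
        tl.store(y_ptr + offs, x, mask=mask)


    n = 1000                                             # deliberately not a multiple of BLOCK
    x = torch.randn(n, device='cuda')
    y = torch.empty_like(x)
    copy_kernel[(triton.cdiv(n, 128),)](x, y, n, BLOCK=128)
    assert torch.equal(x, y)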
@@ -299,7 +308,8 @@ def _bwd_kernel_one_col_block(
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
-        tl.debug_barrier()
+        if not EVEN_M:
+            tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
@@ -912,3 +912,44 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [True])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
+def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
+    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
+        pytest.skip()  # Reference implementation OOM
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 32
+    nheads = 4
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+
+    q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
+    output_0 = flash_attn_func(q, k, v, causal)
+
+    g = torch.randn_like(output_0)
+    dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
+
+    # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
+    for i in range(10000):
+        output = flash_attn_func(q, k, v, causal)
+        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
+        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(output, output_0)
+        # assert torch.equal(dq, dq_0)
+        # assert torch.equal(dk, dk_0)
+        # assert torch.equal(dv, dv_0)
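The test above re-runs the forward pass many times against a cached reference output to catch nondeterminism from the race. One way to run just this test from Python, assuming standard pytest discovery from a repo checkout (the selection by name is an assumption about how the test is collected):

    import pytest

    # '-k' filters by test name; run from a directory where pytest can discover the test module.
    raise SystemExit(pytest.main(['-q', '-k', 'test_flash_attn_triton_race_condition']))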