Fix race condition in Triton fwd
This commit is contained in:
parent
215930bce3
commit
9b0bc97872
@ -1,5 +1,5 @@
|
||||
"""
|
||||
Based on the FlashAttention implementation from Phil Tillet.
|
||||
We use the FlashAttention implementation from Phil Tillet a starting point.
|
||||
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
|
||||
Changes:
|
||||
@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
|
||||
- Make the backward for d=128 much faster by reducing register spilling.
|
||||
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
|
||||
small batch size * nheads.
|
||||
|
||||
Differences between this Triton version and the CUDA version:
|
||||
- Triton version doesn't support dropout.
|
||||
- Triton forward is generally faster than CUDA forward.
|
||||
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
|
||||
slower in other cases.
|
||||
- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
|
||||
"""
|
||||
|
||||
import math
|
||||
@ -26,7 +33,8 @@ import triton.language as tl
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
|
||||
# This config has a race condition when EVEN_M == False, disabling it for now.
|
||||
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
|
||||
],
|
||||
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
|
||||
)
|
||||
@ -34,6 +42,7 @@ import triton.language as tl
|
||||
{
|
||||
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
||||
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
||||
# "EVEN_N": lambda args: False,
|
||||
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
||||
}
|
||||
)
|
||||
@ -95,7 +104,7 @@ def _fwd_kernel(
|
||||
for start_n in range(0, end_n, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
if EVEN_N:
|
||||
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
|
||||
if EVEN_HEADDIM:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
else:
|
||||
@ -129,7 +138,7 @@ def _fwd_kernel(
|
||||
acc_o_scale = tl.load(t_ptrs)
|
||||
acc_o = acc_o * acc_o_scale[:, None]
|
||||
# update acc_o
|
||||
if EVEN_N:
|
||||
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
|
||||
if EVEN_HEADDIM:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn)
|
||||
else:
|
||||
@ -299,7 +308,8 @@ def _bwd_kernel_one_col_block(
|
||||
# compute dp = dot(v, do)
|
||||
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
|
||||
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
|
||||
tl.debug_barrier()
|
||||
if not EVEN_M:
|
||||
tl.debug_barrier()
|
||||
dp = tl.dot(do, v, trans_b=True)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
|
||||
|
||||
@ -912,3 +912,44 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
|
||||
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
|
||||
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize('dtype', [torch.float16])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
# @pytest.mark.parametrize('causal', [True])
|
||||
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
|
||||
# @pytest.mark.parametrize('d', [64])
|
||||
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
|
||||
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
|
||||
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
|
||||
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
|
||||
pytest.skip() # Reference implementation OOM
|
||||
device = 'cuda'
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 32
|
||||
nheads = 4
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
|
||||
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
|
||||
|
||||
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
|
||||
output_0 = flash_attn_func(q, k, v, causal)
|
||||
|
||||
g = torch.randn_like(output_0)
|
||||
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
|
||||
|
||||
# Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
|
||||
for i in range(10000):
|
||||
output = flash_attn_func(q, k, v, causal)
|
||||
# print(f'Output max diff: {(output - output_0).abs().max().item()}')
|
||||
# dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
|
||||
# print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
|
||||
# print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
|
||||
# print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
|
||||
assert torch.equal(output, output_0)
|
||||
# assert torch.equal(dq, dq_0)
|
||||
# assert torch.equal(dk, dk_0)
|
||||
# assert torch.equal(dv, dv_0)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user