Fix race condition in Triton fwd

This commit is contained in:
Tri Dao 2022-10-31 14:34:22 -07:00
parent 215930bce3
commit 9b0bc97872
2 changed files with 56 additions and 5 deletions

View File

@ -1,5 +1,5 @@
"""
Based on the FlashAttention implementation from Phil Tillet.
We use the FlashAttention implementation from Phil Tillet a starting point.
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
Changes:
@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.
Differences between this Triton version and the CUDA version:
- Triton version doesn't support dropout.
- Triton forward is generally faster than CUDA forward.
- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
slower in other cases.
- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
"""
import math
@ -26,7 +33,8 @@ import triton.language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
# This config has a race condition when EVEN_M == False, disabling it for now.
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
],
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
)
@ -34,6 +42,7 @@ import triton.language as tl
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
# "EVEN_N": lambda args: False,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@ -95,7 +104,7 @@ def _fwd_kernel(
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
if EVEN_N:
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
@ -129,7 +138,7 @@ def _fwd_kernel(
acc_o_scale = tl.load(t_ptrs)
acc_o = acc_o * acc_o_scale[:, None]
# update acc_o
if EVEN_N:
if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
@ -299,7 +308,8 @@ def _bwd_kernel_one_col_block(
# compute dp = dot(v, do)
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
tl.debug_barrier()
if not EVEN_M:
tl.debug_barrier()
dp = tl.dot(do, v, trans_b=True)
# compute ds = p * (dp - delta[:, None])
# Putting the subtraction after the dp matmul (instead of before) is slightly faster

View File

@ -912,3 +912,44 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 32
nheads = 4
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
output_0 = flash_attn_func(q, k, v, causal)
g = torch.randn_like(output_0)
dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
# Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
for i in range(10000):
output = flash_attn_func(q, k, v, causal)
# print(f'Output max diff: {(output - output_0).abs().max().item()}')
# dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
# print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
# print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
# print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
assert torch.equal(output, output_0)
# assert torch.equal(dq, dq_0)
# assert torch.equal(dk, dk_0)
# assert torch.equal(dv, dv_0)