From 9b0bc978729d5470aaf5667b8f65df6fa2b3a007 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Mon, 31 Oct 2022 14:34:22 -0700
Subject: [PATCH] Fix race condition in Triton fwd

---
 flash_attn/flash_attn_triton.py | 20 ++++++++++++----
 tests/test_flash_attn.py        | 41 +++++++++++++++++++++++++++++++++
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 6376605..9ec45c7 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -1,5 +1,5 @@
 """
-Based on the FlashAttention implementation from Phil Tillet.
+We use the FlashAttention implementation from Phil Tillet as a starting point.
 https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
 
 Changes:
@@ -13,6 +13,13 @@ more testing since there seems to be some race conditions due to the Triton comp
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
 small batch size * nheads.
+
+Differences between this Triton version and the CUDA version:
+- Triton version doesn't support dropout.
+- Triton forward is generally faster than CUDA forward.
+- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly
+slower in other cases.
+- Triton version does not yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 """
 
 import math
@@ -26,7 +33,8 @@ import triton.language as tl
 @triton.autotune(
     configs=[
         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
-        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
+        # This config has a race condition when EVEN_M == False, disabling it for now.
+        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
     ],
     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
 )
@@ -34,6 +42,7 @@
     {
         "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
         "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
+        # "EVEN_N": lambda args: False,
         "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
     }
 )
@@ -95,7 +104,7 @@ def _fwd_kernel(
     for start_n in range(0, end_n, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- compute qk ----
-        if EVEN_N:
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
             if EVEN_HEADDIM:
                 k = tl.load(k_ptrs + start_n * stride_kn)
             else:
@@ -129,7 +138,7 @@
         acc_o_scale = tl.load(t_ptrs)
         acc_o = acc_o * acc_o_scale[:, None]
         # update acc_o
-        if EVEN_N:
+        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
             if EVEN_HEADDIM:
                 v = tl.load(v_ptrs + start_n * stride_vn)
             else:
@@ -299,7 +308,8 @@ def _bwd_kernel_one_col_block(
         # compute dp = dot(v, do)
         # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
         # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
-        tl.debug_barrier()
+        if not EVEN_M:
+            tl.debug_barrier()
         dp = tl.dot(do, v, trans_b=True)
         # compute ds = p * (dp - delta[:, None])
         # Putting the subtraction after the dp matmul (instead of before) is slightly faster
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index f64f190..aa7b398 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -912,3 +912,44 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+
+
+@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [True])
+@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
+# @pytest.mark.parametrize('d', [64])
+# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
+@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)])
+def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype):
+    if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
+        pytest.skip()  # Reference implementation OOM
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 32
+    nheads = 4
+    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
+    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
+
+    q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
+    output_0 = flash_attn_func(q, k, v, causal)
+
+    g = torch.randn_like(output_0)
+    dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g)
+
+    # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic
+    for i in range(10000):
+        output = flash_attn_func(q, k, v, causal)
+        # print(f'Output max diff: {(output - output_0).abs().max().item()}')
+        # dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+        # print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
+        # print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
+        # print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
+        assert torch.equal(output, output_0)
+        # assert torch.equal(dq, dq_0)
+        # assert torch.equal(dk, dk_0)
+        # assert torch.equal(dv, dv_0)
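
The regression test above amounts to a simple determinism check: run the Triton forward repeatedly on identical inputs and require bitwise-equal outputs, since a race condition tends to show up as a rare mismatch rather than a hard error. Below is a minimal standalone sketch of that check, assuming this repo's module path (flash_attn.flash_attn_triton) and a GPU supported by the Triton kernel; the helper name check_fwd_determinism and its default shapes are illustrative, chosen to mirror the new test.

import torch

from flash_attn.flash_attn_triton import flash_attn_func


def check_fwd_determinism(batch_size=32, nheads=4, seqlen_q=113, seqlen_k=203, d=64,
                          causal=False, dtype=torch.float16, n_repeats=1000):
    # Illustrative helper (not part of the repo): hammer the Triton forward with
    # identical inputs and compare every output against the first one.
    device = 'cuda'
    torch.random.manual_seed(0)
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
    k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device,
                       dtype=dtype).unbind(dim=2)
    # Reference output; later calls should match it bit for bit if the kernel is race-free.
    output_0 = flash_attn_func(q, k, v, causal)
    for _ in range(n_repeats):
        output = flash_attn_func(q, k, v, causal)
        if not torch.equal(output, output_0):
            return False
    return True


if __name__ == '__main__':
    # With seqlen_q=113 and BLOCK_M=128, seqlen_q % BLOCK_M != 0, so EVEN_M is False:
    # exactly the case the new "EVEN_N & EVEN_M" guard in the forward kernel covers.
    print('forward deterministic:', check_fwd_determinism())

The parametrized test itself can be run with pytest -k test_flash_attn_triton_race_condition, which repeats the same loop across the listed head dimensions, sequence lengths, and causal settings.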