From 7479757191c04cc1d5a029b0b34c5064278c93ef Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 6 Nov 2022 11:46:55 -0800
Subject: [PATCH] Fix pipelining bug in Triton bwd with bias_type=matrix

---
 flash_attn/flash_attn_triton.py | 52 +++++++++++++++++++++++----------
 tests/test_flash_attn.py        | 20 ++++++-------
 2 files changed, 46 insertions(+), 26 deletions(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 53100fa..274a083 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -18,6 +18,7 @@ small batch size * nheads.
 Caution:
 - This is an *experimental* implementation. The forward pass should be quite robust but
 I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
+- This implementation has only been tested on A100.
 - If you plan to use headdim other than 64 and 128, you should test for race conditions
 (due to the Triton compiler), as done in tests/test_flash_attn.py
 "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
@@ -250,6 +251,29 @@ def _bwd_preprocess_do_o_dot(
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
 
 
+@triton.jit
+def _bwd_store_dk_dv(
+    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
+    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
+):
+    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
+    # if we just call tl.store(dv_ptrs), there's a race condition
+    if EVEN_N & EVEN_M:
+        if EVEN_HEADDIM:
+            tl.store(dv_ptrs, dv)
+            tl.store(dk_ptrs, dk)
+        else:
+            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
+            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
+    else:
+        if EVEN_HEADDIM:
+            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
+            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
+        else:
+            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+
+
 @triton.jit
 def _bwd_kernel_one_col_block(
     start_n,
@@ -287,6 +311,16 @@ def _bwd_kernel_one_col_block(
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
+    # There seems to be some problem with Triton pipelining that makes results wrong for
+    # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
+    # may have zero step, and pipelining with the bias matrix could screw it up.
+    # So we just exit early.
+    if begin_m >= seqlen_q:
+        dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
+        dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
+        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
+                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
+        return
     # k and v stay in SRAM throughout
     # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
     # if we just call tl.load(k_ptrs), we get the wrong output!
@@ -437,22 +471,8 @@ def _bwd_kernel_one_col_block(
     # write-back
     dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
     dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-    # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
-    # if we just call tl.store(dv_ptrs), there's a race condition
-    if EVEN_N & EVEN_M:
-        if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv)
-            tl.store(dk_ptrs, dk)
-        else:
-            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
-            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
-    else:
-        if EVEN_HEADDIM:
-            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
-            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
-        else:
-            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
-            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
+    _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
+                     EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
 
 
 def init_to_zero(name):
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index e4975a9..b85da45 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -864,14 +864,13 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [False])
+# @pytest.mark.parametrize('causal', [True])
 @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-# @pytest.mark.parametrize('d', [64])
-# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
+# @pytest.mark.parametrize('d', [48])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1023)])
 @pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
-# @pytest.mark.parametrize('bias_shape', (['1h1k']))
+# @pytest.mark.parametrize('bias_shape', (['1hqk']))
 def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
@@ -935,13 +934,13 @@ def test_flash_attn_triton_output(seqlen_q, seqlen_k, d, causal, dtype, bias_sha
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.bfloat16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [False])
+# @pytest.mark.parametrize('causal', [True])
 @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96])
-# @pytest.mark.parametrize('d', [96])
-# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
+# @pytest.mark.parametrize('d', [64])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203)])
 @pytest.mark.parametrize('bias_shape', ([None, '1h1k', '1hqk', 'b11k', 'b1qk']))
+# @pytest.mark.parametrize('bias_shape', (['b1qk']))
 def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype, bias_shape):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
@@ -979,6 +978,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
         output = flash_attn_func(q, k, v, bias, causal)
         output_equal = torch.equal(output, output_0)
         if not output_equal:  # Printing / computing diff sometimes makes the race condition disappear
+            print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
             print(f'Output max diff: {(output - output_0).abs().max().item()}')
         assert torch.equal(output, output_0)
         dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
@@ -986,7 +986,7 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype,
         dk_equal = torch.equal(dk, dk_0)
         dv_equal = torch.equal(dv, dv_0)
         if not (dq_equal and dk_equal and dv_equal):
-            print(f'{i = }')
+            print(f'{dtype = }, {causal = }, {d = }, {seqlen_q = }, {seqlen_k = }, {bias_shape = }, {i = }')
             print(f'dQ max diff: {(dq - dq_0).abs().max().item()}')
             print(f'dK max diff: {(dk - dk_0).abs().max().item()}')
             print(f'dV max diff: {(dv - dv_0).abs().max().item()}')
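
A minimal sketch (separate from the patch itself) of exercising the configuration called out in the new comment in _bwd_kernel_one_col_block: headdim=64, seqlen=(113, 255), bias_type='matrix'. It assumes the (batch, seqlen, nheads, headdim) tensor layout and a 'b1qk'-style bias of shape (batch, 1, seqlen_q, seqlen_k) as used in tests/test_flash_attn.py; the batch and nheads values are arbitrary, and causal=True is chosen so that, with seqlen_k > seqlen_q, trailing key blocks can hit the new early-exit branch.

import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, nheads, d = 4, 8, 64      # batch/nheads arbitrary; d=64 as in the comment
seqlen_q, seqlen_k = 113, 255    # the seqlens mentioned in the new comment
device, dtype = 'cuda', torch.float16

q = torch.randn(batch, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True)
# bias_type='matrix': a full (seqlen_q, seqlen_k) bias, broadcast over heads ('b1qk' in the tests)
bias = torch.randn(batch, 1, seqlen_q, seqlen_k, device=device, dtype=dtype)

out = flash_attn_func(q, k, v, bias, True)  # same call shape as in the tests; causal=True
# The backward pass runs _bwd_kernel_one_col_block; with causal attention and
# seqlen_k > seqlen_q, column blocks starting past seqlen_q store zero dk/dv and return early.
out.sum().backward()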