diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py index 3b212d4..2fcb721 100644 --- a/flash_attn/flash_attn_triton.py +++ b/flash_attn/flash_attn_triton.py @@ -7,7 +7,7 @@ Changes: - Implement cross-attention (not just self-attention). - Support arbitrary seqlens (not just multiples of 128), for both forward and backward. - [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass -and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require +and backward pass. For the backward pass, head dims that are not 64, 128 will require more testing since there seems to be some race conditions due to the Triton compiler. - Speed up the forward pass a bit, and only store the LSE instead of m and l. - Make the backward for d=128 much faster by reducing register spilling. @@ -17,9 +17,9 @@ small batch size * nheads. Differences between this Triton version and the CUDA version: - Triton version doesn't support dropout. - Triton forward is generally faster than CUDA forward. -- Triton backward is faster than CUDA backward when batch * nheads is small, and might be slightly -slower in other cases. -- Triton version does yet not support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. It is slightly +slower when headdim=128 and batch * nheads is large. +- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). """ import math @@ -282,7 +282,7 @@ def _bwd_kernel_one_col_block( qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. # Also wrong for headdim=64. - if not EVEN_M: + if not (EVEN_M & EVEN_HEADDIM): tl.debug_barrier() lse_i = tl.load(LSE + offs_m_curr) p = tl.exp(qk * softmax_scale - lse_i[:, None]) @@ -316,6 +316,9 @@ def _bwd_kernel_one_col_block( if not EVEN_M: tl.debug_barrier() dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() # compute ds = p * (dp - delta[:, None]) # Putting the subtraction after the dp matmul (instead of before) is slightly faster Di = tl.load(D + offs_m_curr) @@ -390,10 +393,6 @@ def _bwd_kernel_one_col_block( def init_to_zero(name): - # def fn(nargs): - # with torch.no_grad(): - # nargs[name].zero_() - # return fn return lambda nargs: nargs[name].zero_() @triton.autotune( @@ -406,15 +405,8 @@ def init_to_zero(name): # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1), - # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1), - # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1), ], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'], - # reset_to_zero=['DQ'] ) @triton.heuristics( { diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 4a91458..1ce3837 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -1,4 +1,5 @@ import math +from functools import partial import torch import torch.nn.functional as F @@ -858,14 +859,14 @@ from flash_attn.flash_attn_triton import flash_attn_func @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100') @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize('causal', [False, True]) # @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96]) -# @pytest.mark.parametrize('d', [40]) +# @pytest.mark.parametrize('d', [48]) # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)]) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1024, 1024)]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1023)]) def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype): if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM @@ -916,13 +917,13 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype): @pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100') @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.float16]) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) @pytest.mark.parametrize('causal', [False, True]) # @pytest.mark.parametrize('causal', [True]) -@pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96]) -# @pytest.mark.parametrize('d', [64]) +# @pytest.mark.parametrize('d', [40, 48, 64, 128, 80, 88, 96]) +@pytest.mark.parametrize('d', [64, 128]) # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (2048, 2048)]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (91, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (1023, 1024), (1024, 1023), (2048, 2048)]) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(1023, 1024)]) def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype): if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: @@ -941,7 +942,10 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype): g = torch.randn_like(output_0) dq_0, dk_0, dv_0 = torch.autograd.grad(output_0, (q, k, v), g) - # Disable the SEQUENCE_PARALLEL option for the bwd to make sure it's deterministic + # The SEQUENCE_PARALLEL option for the bwd to makes dq non-deterministic + deterministic_dq = False + equal_fn = (torch.equal if deterministic_dq + else partial(torch.allclose, atol=1e-3 if dtype == torch.bfloat16 else 1e-5)) for i in range(10000): output = flash_attn_func(q, k, v, causal) output_equal = torch.equal(output, output_0) @@ -949,13 +953,13 @@ def test_flash_attn_triton_race_condition(seqlen_q, seqlen_k, d, causal, dtype): print(f'Output max diff: {(output - output_0).abs().max().item()}') assert torch.equal(output, output_0) dq, dk, dv = torch.autograd.grad(output, (q, k, v), g) - dq_equal = torch.equal(dq, dq_0) + dq_equal = equal_fn(dq, dq_0) dk_equal = torch.equal(dk, dk_0) dv_equal = torch.equal(dv, dv_0) if not (dq_equal and dk_equal and dv_equal): print(f'dQ max diff: {(dq - dq_0).abs().max().item()}') print(f'dK max diff: {(dk - dk_0).abs().max().item()}') print(f'dV max diff: {(dv - dv_0).abs().max().item()}') - assert torch.equal(dq, dq_0) + assert equal_fn(dq, dq_0) assert torch.equal(dk, dk_0) assert torch.equal(dv, dv_0)