diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py index 489f56d..7e205ce 100644 --- a/flash_attn/flash_attn_triton.py +++ b/flash_attn/flash_attn_triton.py @@ -4,6 +4,7 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention Changes: - Support both causal and non-causal attention. +- Support arbitrary seqlens (not just multiples of 128) in the forward pass. - Speed up the forward pass a bit (and only store the LSE instead of m and l). - Make the backward for d=128 much faster by reducing register spilling. - Add the option to parallelize the backward pass across seqlen_k, to deal with the case of @@ -30,7 +31,7 @@ import triton.language as tl @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, - "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, } ) @triton.jit @@ -42,7 +43,7 @@ def _fwd_kernel( stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om, - nheads, seqlen_q, seqlen_k, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, @@ -68,12 +69,14 @@ def _fwd_kernel( k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) # initialize pointer to m and l - t_ptrs = TMP + off_hb * seqlen_q + offs_m + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # load q: it will stay in SRAM throughout - if EVEN_M: + # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler? + if EVEN_M & EVEN_N: q = tl.load(q_ptrs) else: q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) @@ -130,7 +133,7 @@ def _fwd_kernel( start_m = tl.program_id(0) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) # write back l and m - lse_ptrs = Lse + off_hb * seqlen_q + offs_m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m tl.store(lse_ptrs, lse_i) # initialize pointers to output offs_n = tl.arange(0, BLOCK_HEADDIM) @@ -373,7 +376,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None): k.stride(0), k.stride(2), k.stride(1), v.stride(0), v.stride(2), v.stride(1), o.stride(0), o.stride(2), o.stride(1), - nheads, seqlen_q, seqlen_k, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) # Can't use kwargs here because triton autotune expects key to be args, not kwargs # IS_CAUSAL=causal, BLOCK_HEADDIM=d, diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index ae32668..4713f11 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -855,15 +855,17 @@ def test_flash_attn_multigpu(): from flash_attn.flash_attn_triton import flash_attn_func + +@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100') @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) -# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +# @pytest.mark.parametrize('dtype', [torch.float16]) @pytest.mark.parametrize('causal', [False, True]) -# @pytest.mark.parametrize('causal', [True]) +# @pytest.mark.parametrize('causal', [False]) @pytest.mark.parametrize('d', [64, 128]) # @pytest.mark.parametrize('d', [64]) # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) -@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)]) -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)]) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)]) def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype): if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: pytest.skip() # Reference implementation OOM @@ -885,22 +887,25 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype): print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') - g = torch.randn_like(output) - dq, dk, dv = torch.autograd.grad(output, (q, k, v), g) - dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g) - dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g) - print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') - print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') - print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') - print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') - print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') - print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') + run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0) + if run_bwd: + g = torch.randn_like(output) + dq, dk, dv = torch.autograd.grad(output, (q, k, v), g) + dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g) + dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g) + print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') + print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') + print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') + print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') + print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') + print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') # Check that FlashAttention's numerical error is at most twice the numerical error # of a Pytorch implementation. assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) - assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() - assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() - assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + if run_bwd: + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()