From b910bf14c1baa7e6a4886c1cd07d65e7a61390c0 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 30 Oct 2022 21:50:53 -0700
Subject: [PATCH] Support arbitrary seqlens (both q & k) in Triton bwd

---
 flash_attn/flash_attn_triton.py | 41 ++++++++++++++++++++++-----------
 tests/test_flash_attn.py        | 33 ++++++++++++--------------
 2 files changed, 42 insertions(+), 32 deletions(-)

diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index 7e846b8..593e1e0 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -5,10 +5,8 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 Changes:
 - Implement both causal and non-causal attention.
 - Implement cross-attention (not just self-attention).
-- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
-- Support arbitrary seqlen_k (not just multiples of 128) in the backward pass. However, seqlen_q
-must still be a multiple of 128.
-- Speed up the forward pass a bit (and only store the LSE instead of m and l).
+- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
+- Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of small
 batch size * nheads.
@@ -18,8 +16,6 @@
 import math
 
 import torch
-from einops import rearrange
-
 import triton
 import triton.language as tl
 
@@ -213,7 +209,9 @@ def _bwd_kernel_one_col_block(
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     # k and v stay in SRAM throughout
-    if EVEN_N:
+    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
+    # if we just call tl.load(k_ptrs), we get the wrong output!
+    if EVEN_N & EVEN_M:
         k = tl.load(k_ptrs)
         v = tl.load(v_ptrs)
     else:
@@ -225,7 +223,10 @@
         start_m = tl.multiple_of(start_m, BLOCK_M)
         offs_m_curr = start_m + offs_m
         # load q, k, v, do on-chip
-        q = tl.load(q_ptrs)
+        if EVEN_M:
+            q = tl.load(q_ptrs)
+        else:
+            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
         if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
@@ -235,7 +236,10 @@
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
         # compute dv
-        do = tl.load(do_ptrs)
+        if EVEN_M:
+            do = tl.load(do_ptrs)
+        else:
+            do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         dp = tl.dot(do, v, trans_b=True)
@@ -249,12 +253,22 @@
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-            dq += tl.dot(ds, k)
-            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            if EVEN_M:
+                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            else:
+                dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
+                             eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
+                         eviction_policy="evict_last")
         else:  # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            tl.atomic_add(dq_ptrs, dq)
+            if EVEN_M:
+                tl.atomic_add(dq_ptrs, dq)
+            else:
+                tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
         # increment pointers
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
@@ -417,7 +431,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
         do = do.contiguous()
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
-    assert seqlen_q % 128 == 0, 'Backward pass currently only supports seqlens that are multiples of 128'
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index dc72817..b2d10ba 100644
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [64, 128])
 # @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 211)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 128)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip()  # Reference implementation OOM
@@ -887,25 +887,22 @@
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
 
-    run_bwd = seqlen_q % 128 == 0
-    if run_bwd:
-        g = torch.randn_like(output)
-        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
-        dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
-        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
-        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
-        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
-        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
-        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
-        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
+    g = torch.randn_like(output)
+    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+    dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
+    dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
+    print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
+    print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
+    print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
     # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
 
-    if run_bwd:
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
+    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
+    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
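
Usage note (illustrative, not part of the patch): after this change the Triton backward no longer requires seqlen_q to be a multiple of 128, so both forward and backward run for arbitrary query and key lengths. The sketch below mirrors the updated test; the (batch, seqlen, nheads, headdim) layout comes from _flash_attn_backward above, but the keyword usage of flash_attn_func (causal=...) is an assumption, so check the function definition in flash_attn/flash_attn_triton.py before relying on it.

    # Minimal sketch: arbitrary seqlens (113 and 203 are not multiples of 128).
    import torch
    from flash_attn.flash_attn_triton import flash_attn_func

    batch, nheads, d = 2, 8, 64
    seqlen_q, seqlen_k = 113, 203
    q = torch.randn(batch, seqlen_q, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn(batch, seqlen_k, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
    v = torch.randn(batch, seqlen_k, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)

    out = flash_attn_func(q, k, v, causal=True)  # forward already handled arbitrary seqlens
    out.backward(torch.randn_like(out))          # backward now works without padding seqlen_q to 128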