Support arbitrary seqlens (both q & k) in Triton bwd

Tri Dao 2022-10-30 21:50:53 -07:00
parent dc55469355
commit b910bf14c1
2 changed files with 42 additions and 32 deletions

@@ -5,10 +5,8 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
Changes:
- Implement both causal and non-causal attention.
- Implement cross-attention (not just self-attention).
- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
- Support arbitrary seqlen_k (not just multiples of 128) in the backward pass. However, seqlen_q
must still be a multiple of 128.
- Speed up the forward pass a bit (and only store the LSE instead of m and l).
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
- Speed up the forward pass a bit, and only store the LSE instead of m and l.
- Make the backward for d=128 much faster by reducing register spilling.
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of
small batch size * nheads.
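
Note: the "arbitrary seqlens" item above comes down to guarding every block load and store with a boundary mask. As a standalone illustration (not part of this commit; _scale_rows_kernel and scale_rows are made-up names), here is the same pattern in a minimal Triton kernel that scales the rows of a (seqlen, d) tensor when seqlen is not a multiple of the block size:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_rows_kernel(x_ptr, out_ptr, seqlen, stride_m, scale,
                       BLOCK_M: tl.constexpr, BLOCK_D: tl.constexpr):
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_D)
    ptrs = x_ptr + offs_m[:, None] * stride_m + offs_d[None, :]
    # Rows past seqlen are out of bounds: mask them on load and fill with 0.0
    # instead of reading garbage (this is what the fwd/bwd kernels do for q, k, v, do).
    x = tl.load(ptrs, mask=offs_m[:, None] < seqlen, other=0.0)
    out_ptrs = out_ptr + offs_m[:, None] * stride_m + offs_d[None, :]
    # The same mask keeps the store from writing past the end of the buffer.
    tl.store(out_ptrs, x * scale, mask=offs_m[:, None] < seqlen)

def scale_rows(x, scale, BLOCK_M=128):
    seqlen, d = x.shape  # x must be contiguous, d a power of 2 (e.g. 64 or 128)
    out = torch.empty_like(x)
    grid = (triton.cdiv(seqlen, BLOCK_M),)
    _scale_rows_kernel[grid](x, out, seqlen, x.stride(0), scale,
                             BLOCK_M=BLOCK_M, BLOCK_D=d)
    return out
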
@@ -18,8 +16,6 @@ import math
import torch
from einops import rearrange
import triton
import triton.language as tl
@@ -213,7 +209,9 @@ def _bwd_kernel_one_col_block(
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
# k and v stay in SRAM throughout
if EVEN_N:
# [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
# if we just call tl.load(k_ptrs), we get the wrong output!
if EVEN_N & EVEN_M:
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
else:
@@ -225,7 +223,10 @@ def _bwd_kernel_one_col_block(
start_m = tl.multiple_of(start_m, BLOCK_M)
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
if EVEN_M:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
# recompute p = softmax(qk, dim=-1).T
qk = tl.dot(q, k, trans_b=True)
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
@@ -235,7 +236,10 @@ def _bwd_kernel_one_col_block(
lse_i = tl.load(LSE + offs_m_curr)
p = tl.exp(qk * softmax_scale - lse_i[:, None])
# compute dv
do = tl.load(do_ptrs)
if EVEN_M:
do = tl.load(do_ptrs)
else:
do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
# compute dp = dot(v, do)
dp = tl.dot(do, v, trans_b=True)
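
Note: the hunk above recomputes the attention probabilities from a single stored per-row statistic, the log-sum-exp, rather than from the separate max m and normalizer l. A quick standalone PyTorch check of the identity it relies on, softmax(s) = exp(s - logsumexp(s)):

import torch

s = torch.randn(4, 7)                  # scores for one block of queries
lse = torch.logsumexp(s, dim=-1)       # the only per-row statistic the kernel stores
p = torch.exp(s - lse[:, None])        # what the backward recomputes, cf. p = tl.exp(qk * softmax_scale - lse_i[:, None])
assert torch.allclose(p, torch.softmax(s, dim=-1), atol=1e-6)
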
@@ -249,12 +253,22 @@ def _bwd_kernel_one_col_block(
dk += tl.dot(ds, q, trans_a=True)
# compute dq
if not ATOMIC_ADD:
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
if EVEN_M:
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
else:
dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
eviction_policy="evict_last")
dq += tl.dot(ds, k)
tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
eviction_policy="evict_last")
else: # If we're parallelizing across the seqlen_k dimension
dq = tl.dot(ds, k)
tl.atomic_add(dq_ptrs, dq)
if EVEN_M:
tl.atomic_add(dq_ptrs, dq)
else:
tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
# increment pointers
dq_ptrs += BLOCK_M * stride_dqm
q_ptrs += BLOCK_M * stride_qm
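
Note: when the backward is parallelized across seqlen_k, several programs accumulate into the same dq rows, which is why the masked tl.atomic_add above is needed. A standalone sketch of that accumulation pattern (hypothetical names, not from this repo), computing row sums of a 2-D tensor with one program per (row block, column block):

import torch
import triton
import triton.language as tl

@triton.jit
def _rowsum_atomic_kernel(x_ptr, out_ptr, seqlen_m, seqlen_n, stride_m,
                          BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    offs_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < seqlen_m) & (offs_n[None, :] < seqlen_n)
    x = tl.load(x_ptr + offs_m[:, None] * stride_m + offs_n[None, :], mask=mask, other=0.0)
    partial = tl.sum(x.to(tl.float32), axis=1)
    # Programs along axis 1 all touch the same rows of out, so the add must be atomic;
    # the mask keeps out-of-range rows (seqlen_m not a multiple of BLOCK_M) untouched.
    tl.atomic_add(out_ptr + offs_m, partial, mask=offs_m < seqlen_m)

def rowsum(x, BLOCK_M=64, BLOCK_N=64):
    seqlen_m, seqlen_n = x.shape  # x must be contiguous
    out = torch.zeros(seqlen_m, device=x.device, dtype=torch.float32)
    grid = (triton.cdiv(seqlen_m, BLOCK_M), triton.cdiv(seqlen_n, BLOCK_N))
    _rowsum_atomic_kernel[grid](x, out, seqlen_m, seqlen_n, x.stride(0),
                                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return out
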
@@ -417,7 +431,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
do = do.contiguous()
batch, seqlen_q, nheads, d = q.shape
_, seqlen_k, _, _ = k.shape
assert seqlen_q % 128 == 0, 'Backward pass currently only supports seqlens that are multiples of 128'
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
assert lse.shape == (batch, nheads, seqlen_q_rounded)
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
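
Note: with the seqlen_q % 128 assert gone, the backward can be exercised with uneven lengths directly. A minimal usage sketch under assumptions: the (batch, seqlen, nheads, headdim) layout used in this file and flash_attn_func called with its default optional arguments (non-causal); the signature is assumed, and the shapes and dtype mirror the updated test below.

import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, nheads, d = 2, 4, 64
seqlen_q, seqlen_k = 113, 203  # neither a multiple of 128
q = torch.randn(batch, seqlen_q, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)

out = flash_attn_func(q, k, v)  # default (non-causal) settings; signature assumed
g = torch.randn_like(out)
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)  # previously required seqlen_q % 128 == 0
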

@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 211)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 128)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
@@ -887,25 +887,22 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
run_bwd = seqlen_q % 128 == 0
if run_bwd:
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
if run_bwd:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
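
Note: the assertions above encode the acceptance rule spelled out in the comment: the Triton kernel's max error against the reference must be at most twice that of a plain PyTorch implementation in the same dtype. A small helper capturing that rule (hypothetical, not in the test file):

def assert_within_2x_of_pytorch(x, x_ref, x_pt):
    # x: Triton output/grad, x_ref: reference, x_pt: plain PyTorch run in the same dtype as x
    err = (x - x_ref).abs().max().item()
    err_pt = (x_pt - x_ref).abs().max().item()
    assert err <= 2 * err_pt, f'max error {err:.3e} exceeds 2x PyTorch max error {err_pt:.3e}'

# e.g. assert_within_2x_of_pytorch(dq, dq_ref, dq_pt)
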