Support arbitrary seqlens (both q & k) in Triton bwd
commit b910bf14c1 (parent dc55469355)
@@ -5,10 +5,8 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
 Changes:
 - Implement both causal and non-causal attention.
 - Implement cross-attention (not just self-attention).
-- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
-- Support arbitrary seqlen_k (not just multiples of 128) in the backward pass. However, seqlen_q
-  must still be a multiple of 128.
-- Speed up the forward pass a bit (and only store the LSE instead of m and l).
+- Support arbitrary seqlens (not just multiples of 128), for both forward and backward.
+- Speed up the forward pass a bit, and only store the LSE instead of m and l.
 - Make the backward for d=128 much faster by reducing register spilling.
 - Optionally parallelize the backward pass across seqlen_k, to deal with the case of
   small batch size * nheads.
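For context (not part of the diff): a minimal usage sketch of what the relaxed restriction allows, assuming the flash_attn_func interface exercised by the test further down, i.e. q/k/v of shape (batch, seqlen, nheads, headdim) and a positional (q, k, v, bias, causal) call; the exact signature is an assumption here. Both sequence lengths are deliberately not multiples of 128.

import torch
from flash_attn.flash_attn_triton import flash_attn_func

batch, nheads, d = 2, 4, 64
seqlen_q, seqlen_k = 113, 203  # neither is a multiple of 128
q = torch.randn(batch, seqlen_q, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
k = torch.randn(batch, seqlen_k, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
v = torch.randn(batch, seqlen_k, nheads, d, device='cuda', dtype=torch.float16, requires_grad=True)
out = flash_attn_func(q, k, v, None, True)  # bias=None, causal=True (assumed positional signature)
out.sum().backward()  # the backward pass no longer requires seqlen_q % 128 == 0
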
@@ -18,8 +16,6 @@ import math
 
 import torch
 
-from einops import rearrange
-
 import triton
 import triton.language as tl
 
@@ -213,7 +209,9 @@ def _bwd_kernel_one_col_block(
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
     # k and v stay in SRAM throughout
-    if EVEN_N:
+    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
+    # if we just call tl.load(k_ptrs), we get the wrong output!
+    if EVEN_N & EVEN_M:
         k = tl.load(k_ptrs)
         v = tl.load(v_ptrs)
     else:
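A toy sketch (not code from this commit) of the masked-load pattern the EVEN_M/EVEN_N branches rely on: lanes that fall past the end of a sequence load 0.0 via mask=/other= instead of reading out of bounds, so a partially filled block still feeds correct values into tl.dot. The kernel name and sizes below are made up for illustration.

import torch
import triton
import triton.language as tl

@triton.jit
def masked_copy_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    # Out-of-range lanes read 0.0 and are not stored back.
    x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)
    tl.store(out_ptr + offs, x, mask=offs < n)

x = torch.randn(113, device='cuda')  # deliberately not a multiple of BLOCK
out = torch.empty_like(x)
masked_copy_kernel[(triton.cdiv(x.numel(), 128),)](x, out, x.numel(), BLOCK=128)
assert torch.equal(out, x)
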
@@ -225,7 +223,10 @@ def _bwd_kernel_one_col_block(
         start_m = tl.multiple_of(start_m, BLOCK_M)
         offs_m_curr = start_m + offs_m
         # load q, k, v, do on-chip
-        q = tl.load(q_ptrs)
+        if EVEN_M:
+            q = tl.load(q_ptrs)
+        else:
+            q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
         if not EVEN_N: # Need to mask out otherwise the softmax is wrong
@@ -235,7 +236,10 @@ def _bwd_kernel_one_col_block(
         lse_i = tl.load(LSE + offs_m_curr)
         p = tl.exp(qk * softmax_scale - lse_i[:, None])
         # compute dv
-        do = tl.load(do_ptrs)
+        if EVEN_M:
+            do = tl.load(do_ptrs)
+        else:
+            do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         dv += tl.dot(p.to(do.dtype), do, trans_a=True)
         # compute dp = dot(v, do)
         dp = tl.dot(do, v, trans_b=True)
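The p = tl.exp(qk * softmax_scale - lse_i[:, None]) context line above is the "store only the LSE" change from the changelog: with the per-row log-sum-exp saved by the forward pass, the backward can recompute the softmax without the separate m and l statistics. A small PyTorch sketch of the identity being used (illustrative shapes only):

import torch

scores = torch.randn(4, 7) * 0.3       # plays the role of qk * softmax_scale for one block of rows
lse = torch.logsumexp(scores, dim=-1)   # what the forward pass stores per query row
p = torch.exp(scores - lse[:, None])    # identical to softmax(scores, dim=-1)
assert torch.allclose(p, torch.softmax(scores, dim=-1), atol=1e-6)
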
@@ -249,12 +253,22 @@ def _bwd_kernel_one_col_block(
         dk += tl.dot(ds, q, trans_a=True)
         # compute dq
         if not ATOMIC_ADD:
-            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
-            dq += tl.dot(ds, k)
-            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            if EVEN_M:
+                dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            else:
+                dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
+                             eviction_policy="evict_last")
+                dq += tl.dot(ds, k)
+                tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
+                         eviction_policy="evict_last")
         else: # If we're parallelizing across the seqlen_k dimension
             dq = tl.dot(ds, k)
-            tl.atomic_add(dq_ptrs, dq)
+            if EVEN_M:
+                tl.atomic_add(dq_ptrs, dq)
+            else:
+                tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
         # increment pointers
         dq_ptrs += BLOCK_M * stride_dqm
         q_ptrs += BLOCK_M * stride_qm
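On the ATOMIC_ADD branch: when the backward is parallelized across seqlen_k, each key-block program computes only a partial contribution tl.dot(ds, k) to the same rows of dq, so the contributions must be combined with tl.atomic_add rather than the plain load/add/store path. A toy PyTorch sketch of that decomposition (hypothetical sizes, a serial loop standing in for the parallel programs):

import torch

seqlen_q, seqlen_k, d, BLOCK_N = 128, 384, 64, 128
ds = torch.randn(seqlen_q, seqlen_k)
k = torch.randn(seqlen_k, d)

dq_ref = ds @ k  # serial reference: dq = ds @ k over the full seqlen_k

dq = torch.zeros(seqlen_q, d)
for start_n in range(0, seqlen_k, BLOCK_N):
    blk = slice(start_n, start_n + BLOCK_N)
    dq += ds[:, blk] @ k[blk]  # in the kernel this "+=" is tl.atomic_add across programs

assert torch.allclose(dq, dq_ref, atol=1e-3)
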
@@ -417,7 +431,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_
     do = do.contiguous()
     batch, seqlen_q, nheads, d = q.shape
     _, seqlen_k, _, _ = k.shape
-    assert seqlen_q % 128 == 0, 'Backward pass currently only supports seqlens that are multiples of 128'
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
     assert lse.shape == (batch, nheads, seqlen_q_rounded)
     # dq_accum = torch.zeros_like(q, dtype=torch.float32)
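Note that only the lse buffer is padded to a multiple of 128; seqlen_q itself no longer has to be. For example, with seqlen_q = 113: seqlen_q_rounded = math.ceil(113 / 128) * 128 = 128, so lse is expected to have shape (batch, nheads, 128) while q keeps its true length of 113.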
@@ -860,12 +860,12 @@ from flash_attn.flash_attn_triton import flash_attn_func
 @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
 # @pytest.mark.parametrize('dtype', [torch.float16])
 @pytest.mark.parametrize('causal', [False, True])
-# @pytest.mark.parametrize('causal', [True])
+# @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('d', [64, 128])
 # @pytest.mark.parametrize('d', [64])
 # @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
 @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (108, 256), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
-# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 211)])
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 128)])
 def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
         pytest.skip() # Reference implementation OOM
@@ -887,25 +887,22 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
     print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
     print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
 
-    run_bwd = seqlen_q % 128 == 0
-    if run_bwd:
-        g = torch.randn_like(output)
-        dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
-        dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
-        dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
-        print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
-        print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
-        print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
-        print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
-        print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
-        print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
+    g = torch.randn_like(output)
+    dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
+    dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
+    dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
+    print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
+    print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
+    print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
 
     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
     assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
     # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
 
-    if run_bwd:
-        assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
-        assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
-        assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
+    assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
+    assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
+    assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
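The acceptance criterion in this test is relative rather than an absolute atol/rtol: assuming (as the printed diffs suggest) output_ref is a higher-precision reference and output_pt a PyTorch baseline run in the kernel's dtype, the kernel is allowed at most twice the baseline's error, for the output and for each gradient:

kernel_err = (output - output_ref).abs().max().item()
baseline_err = (output_pt - output_ref).abs().max().item()
assert kernel_err <= 2 * baseline_err  # same pattern as the dq/dk/dv asserts above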