Fix Triton fwd to support seqlen not multiples of 128

This commit is contained in:
Tri Dao 2022-10-30 19:04:00 -07:00
parent b0c0db81f6
commit d11341fd1a
2 changed files with 31 additions and 23 deletions

View File

@ -4,6 +4,7 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
Changes:
- Support both causal and non-causal attention.
- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
- Speed up the forward pass a bit (and only store the LSE instead of m and l).
- Make the backward for d=128 much faster by reducing register spilling.
- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
@ -30,7 +31,7 @@ import triton.language as tl
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
}
)
@triton.jit
@ -42,7 +43,7 @@ def _fwd_kernel(
stride_kb, stride_kh, stride_kn,
stride_vb, stride_vh, stride_vn,
stride_ob, stride_oh, stride_om,
nheads, seqlen_q, seqlen_k,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
@ -68,12 +69,14 @@ def _fwd_kernel(
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
# initialize pointer to m and l
t_ptrs = TMP + off_hb * seqlen_q + offs_m
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
# load q: it will stay in SRAM throughout
if EVEN_M:
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
if EVEN_M & EVEN_N:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
@ -130,7 +133,7 @@ def _fwd_kernel(
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
lse_ptrs = Lse + off_hb * seqlen_q + offs_m
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
tl.store(lse_ptrs, lse_i)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_HEADDIM)
@ -373,7 +376,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
k.stride(0), k.stride(2), k.stride(1),
v.stride(0), v.stride(2), v.stride(1),
o.stride(0), o.stride(2), o.stride(1),
nheads, seqlen_q, seqlen_k,
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,

View File

@ -855,15 +855,17 @@ def test_flash_attn_multigpu():
from flash_attn.flash_attn_triton import flash_attn_func
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('d', [64, 128])
# @pytest.mark.parametrize('d', [64])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)])
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
@ -885,22 +887,25 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0)
if run_bwd:
g = torch.randn_like(output)
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
if run_bwd:
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()