Fix Triton fwd to support seqlen not multiples of 128
This commit is contained in:
parent
b0c0db81f6
commit
d11341fd1a
@ -4,6 +4,7 @@ https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention
|
||||
|
||||
Changes:
|
||||
- Support both causal and non-causal attention.
|
||||
- Support arbitrary seqlens (not just multiples of 128) in the forward pass.
|
||||
- Speed up the forward pass a bit (and only store the LSE instead of m and l).
|
||||
- Make the backward for d=128 much faster by reducing register spilling.
|
||||
- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
|
||||
@ -30,7 +31,7 @@ import triton.language as tl
|
||||
@triton.heuristics(
|
||||
{
|
||||
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
||||
"EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
|
||||
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
@ -42,7 +43,7 @@ def _fwd_kernel(
|
||||
stride_kb, stride_kh, stride_kn,
|
||||
stride_vb, stride_vh, stride_vn,
|
||||
stride_ob, stride_oh, stride_om,
|
||||
nheads, seqlen_q, seqlen_k,
|
||||
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
|
||||
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_HEADDIM: tl.constexpr,
|
||||
@ -68,12 +69,14 @@ def _fwd_kernel(
|
||||
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
||||
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hb * seqlen_q + offs_m
|
||||
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
|
||||
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
if EVEN_M:
|
||||
# [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call
|
||||
# tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler?
|
||||
if EVEN_M & EVEN_N:
|
||||
q = tl.load(q_ptrs)
|
||||
else:
|
||||
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
||||
@ -130,7 +133,7 @@ def _fwd_kernel(
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
lse_ptrs = Lse + off_hb * seqlen_q + offs_m
|
||||
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
||||
tl.store(lse_ptrs, lse_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_HEADDIM)
|
||||
@ -373,7 +376,7 @@ def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
|
||||
k.stride(0), k.stride(2), k.stride(1),
|
||||
v.stride(0), v.stride(2), v.stride(1),
|
||||
o.stride(0), o.stride(2), o.stride(1),
|
||||
nheads, seqlen_q, seqlen_k,
|
||||
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
|
||||
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
|
||||
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
|
||||
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
|
||||
|
||||
@ -855,15 +855,17 @@ def test_flash_attn_multigpu():
|
||||
|
||||
from flash_attn.flash_attn_triton import flash_attn_func
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_sm80, reason='Triton version is only tested on A100')
|
||||
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||
# @pytest.mark.parametrize('dtype', [torch.float16])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
# @pytest.mark.parametrize('causal', [True])
|
||||
# @pytest.mark.parametrize('causal', [False])
|
||||
@pytest.mark.parametrize('d', [64, 128])
|
||||
# @pytest.mark.parametrize('d', [64])
|
||||
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
|
||||
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
|
||||
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(113, 203), (128, 217), (113, 211), (256, 512), (512, 256), (1024, 1024), (2048, 2048)])
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(127, 256)])
|
||||
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
|
||||
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
|
||||
pytest.skip() # Reference implementation OOM
|
||||
@ -885,22 +887,25 @@ def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
|
||||
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
|
||||
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
|
||||
|
||||
g = torch.randn_like(output)
|
||||
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
|
||||
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
|
||||
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
|
||||
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
|
||||
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
|
||||
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
|
||||
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
|
||||
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
|
||||
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
|
||||
run_bwd = (seqlen_q % 128 == 0) and (seqlen_k % 128 == 0)
|
||||
if run_bwd:
|
||||
g = torch.randn_like(output)
|
||||
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
|
||||
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
|
||||
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
|
||||
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
|
||||
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
|
||||
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
|
||||
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
|
||||
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
|
||||
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
|
||||
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
|
||||
|
||||
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
||||
if run_bwd:
|
||||
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user