Implement FlashAttention in Triton
This commit is contained in:
parent
c422fee377
commit
b0c0db81f6
@ -6,9 +6,11 @@ import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler
|
||||
from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
||||
from flash_attn.triton.fused_attention import attention as attention
|
||||
# from flash_attn.triton.fused_attention import attention as attention
|
||||
from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func
|
||||
from flash_attn.flash_attn_triton_og import attention as attention_og
|
||||
|
||||
try:
|
||||
from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax
|
||||
@ -45,19 +47,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True):
|
||||
return output.to(dtype=qkv.dtype)
|
||||
|
||||
|
||||
def attention_triton(q, k, v):
|
||||
"""
|
||||
No dropout and only support causal=True.
|
||||
Triton implementation seems to require q, k, v being contiguous?
|
||||
Arguments:
|
||||
q, k, v: (batch_size, nheads, seqlen, head_dim)
|
||||
Output:
|
||||
output: (batch_size, nheads, seqlen, head_dim)
|
||||
"""
|
||||
softmax_scale = 1.0 / math.sqrt(q.shape[-1])
|
||||
return attention(q, k, v, softmax_scale)
|
||||
|
||||
|
||||
def attention_megatron(qkv):
|
||||
"""
|
||||
Arguments:
|
||||
@ -85,6 +74,10 @@ batch_size = 2
|
||||
seqlen = 4096
|
||||
nheads = 12
|
||||
headdim = 128
|
||||
# batch_size = 64
|
||||
# seqlen = 512
|
||||
# nheads = 8
|
||||
# headdim = 128
|
||||
dropout_p = 0.0
|
||||
causal = True
|
||||
dtype = torch.bfloat16
|
||||
@ -100,9 +93,13 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b
|
||||
benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
|
||||
repeats=repeats, desc='PyTorch Attention')
|
||||
|
||||
benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton')
|
||||
pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True)
|
||||
|
||||
q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
|
||||
requires_grad=True) for _ in range(3)]
|
||||
benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')
|
||||
benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG')
|
||||
# pytorch_profiler(attention, q, k, v, 1.0, backward=True)
|
||||
|
||||
if scaled_upper_triang_masked_softmax is not None:
|
||||
benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention')
|
||||
|
||||
529
flash_attn/flash_attn_triton.py
Normal file
529
flash_attn/flash_attn_triton.py
Normal file
@ -0,0 +1,529 @@
|
||||
"""
|
||||
Based on the FlashAttention implementation from Phil Tillet.
|
||||
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
|
||||
Changes:
|
||||
- Support both causal and non-causal attention.
|
||||
- Speed up the forward pass a bit (and only store the LSE instead of m and l).
|
||||
- Make the backward for d=128 much faster by reducing register spilling.
|
||||
- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of
|
||||
small batch size * nheads.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
|
||||
],
|
||||
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM']
|
||||
)
|
||||
@triton.heuristics(
|
||||
{
|
||||
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
||||
"EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, Out,
|
||||
Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
|
||||
softmax_scale,
|
||||
stride_qb, stride_qh, stride_qm,
|
||||
stride_kb, stride_kh, stride_kn,
|
||||
stride_vb, stride_vh, stride_vn,
|
||||
stride_ob, stride_oh, stride_om,
|
||||
nheads, seqlen_q, seqlen_k,
|
||||
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_HEADDIM: tl.constexpr,
|
||||
EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hb = tl.program_id(1)
|
||||
off_b = off_hb // nheads
|
||||
off_h = off_hb % nheads
|
||||
# off_b = tl.program_id(1)
|
||||
# off_h = tl.program_id(2)
|
||||
# off_hb = off_b * nheads + off_h
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
||||
# Initialize pointers to Q, K, V
|
||||
# Adding parenthesis around indexing might use int32 math instead of int64 math?
|
||||
# https://github.com/openai/triton/issues/741
|
||||
# I'm seeing a tiny bit of difference (5-7us)
|
||||
q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
|
||||
k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
||||
v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
||||
# initialize pointer to m and l
|
||||
t_ptrs = TMP + off_hb * seqlen_q + offs_m
|
||||
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
if EVEN_M:
|
||||
q = tl.load(q_ptrs)
|
||||
else:
|
||||
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
||||
# loop over k, v and update accumulator
|
||||
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
||||
for start_n in range(0, end_n, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
if EVEN_N:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
|
||||
other=0.0)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k, trans_b=True)
|
||||
if not EVEN_N:
|
||||
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
||||
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
|
||||
# Slightly faster to multiply the softmax_scale here since the compiler can then
|
||||
# fuse the mult and add into an fma instruction.
|
||||
p = tl.exp(qk * softmax_scale - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
|
||||
# scale acc_o
|
||||
acc_o_scale = tl.exp(m_i - m_ij)
|
||||
|
||||
# # -- update output accumulator --
|
||||
# BUG: have to store and immediately load
|
||||
tl.store(t_ptrs, acc_o_scale)
|
||||
acc_o_scale = tl.load(t_ptrs)
|
||||
acc_o = acc_o * acc_o_scale[:, None]
|
||||
# update acc_o
|
||||
if EVEN_N:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
|
||||
other=0.0)
|
||||
p = p.to(v.dtype)
|
||||
acc_o += tl.dot(p, v)
|
||||
|
||||
# -- update statistics
|
||||
m_i = m_ij
|
||||
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
||||
lse_i = m_ij + tl.log(l_i_new)
|
||||
|
||||
o_scale = tl.exp(m_i - lse_i)
|
||||
# BUG: have to store and immediately load
|
||||
tl.store(t_ptrs, o_scale)
|
||||
o_scale = tl.load(t_ptrs)
|
||||
acc_o = acc_o * o_scale[:, None]
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
lse_ptrs = Lse + off_hb * seqlen_q + offs_m
|
||||
tl.store(lse_ptrs, lse_i)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_HEADDIM)
|
||||
out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
|
||||
if EVEN_M:
|
||||
tl.store(out_ptrs, acc_o)
|
||||
else:
|
||||
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _bwd_preprocess_do_o_dot(
|
||||
Out, DO, Delta,
|
||||
stride_ob, stride_oh, stride_om,
|
||||
stride_dob, stride_doh, stride_dom,
|
||||
nheads, seqlen_q, seqlen_q_rounded,
|
||||
EVEN_M: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hb = tl.program_id(1)
|
||||
off_b = off_hb // nheads
|
||||
off_h = off_hb % nheads
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
||||
# load
|
||||
if EVEN_M:
|
||||
o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :]).to(tl.float32)
|
||||
else:
|
||||
o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
|
||||
mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32)
|
||||
do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
|
||||
mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32)
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel_one_col_block(
|
||||
start_n,
|
||||
Q, K, V, softmax_scale,
|
||||
DO, DQ, DK, DV,
|
||||
LSE, D,
|
||||
stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
|
||||
seqlen_q, seqlen_k,
|
||||
ATOMIC_ADD: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_HEADDIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
):
|
||||
# We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
|
||||
begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
|
||||
# initialize row/col offsets
|
||||
offs_qm = begin_m + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
|
||||
offs_m = tl.arange(0, BLOCK_M)
|
||||
offs_k = tl.arange(0, BLOCK_HEADDIM)
|
||||
# initialize pointers to value-like data
|
||||
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :])
|
||||
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :])
|
||||
v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :])
|
||||
do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_k[None, :])
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_k[None, :])
|
||||
# initialize dv amd dk
|
||||
dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
|
||||
# k and v stay in SRAM throughout
|
||||
k = tl.load(k_ptrs)
|
||||
v = tl.load(v_ptrs)
|
||||
# loop over rows
|
||||
num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
|
||||
for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
|
||||
start_m = tl.multiple_of(start_m, BLOCK_M)
|
||||
offs_m_curr = start_m + offs_m
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
qk = tl.dot(q, k, trans_b=True)
|
||||
if IS_CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
|
||||
lse_i = tl.load(LSE + offs_m_curr)
|
||||
p = tl.exp(qk * softmax_scale - lse_i[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
||||
# compute dp = dot(v, do)
|
||||
dp = tl.dot(do, v, trans_b=True)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
# Putting the subtraction after the dp matmul (instead of before) is slightly faster
|
||||
Di = tl.load(D + offs_m_curr)
|
||||
# Converting ds to q.dtype here reduces register pressure and makes it much faster
|
||||
# for BLOCK_HEADDIM=128
|
||||
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
||||
# compute dk = dot(ds.T, q)
|
||||
dk += tl.dot(ds, q, trans_a=True)
|
||||
# compute dq
|
||||
if not ATOMIC_ADD:
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dq += tl.dot(ds, k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
else: # If we're parallelizing across the seqlen_k dimension
|
||||
dq = tl.dot(ds, k)
|
||||
tl.atomic_add(dq_ptrs, dq)
|
||||
# increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_dqm
|
||||
q_ptrs += BLOCK_M * stride_qm
|
||||
do_ptrs += BLOCK_M * stride_dom
|
||||
# write-back
|
||||
dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :])
|
||||
dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :])
|
||||
tl.store(dv_ptrs, dv)
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
|
||||
def init_to_zero(name):
|
||||
# def fn(nargs):
|
||||
# with torch.no_grad():
|
||||
# nargs[name].zero_()
|
||||
# return fn
|
||||
return lambda nargs: nargs[name].zero_()
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
|
||||
# Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
|
||||
triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
|
||||
triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
|
||||
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
|
||||
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
|
||||
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
|
||||
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
|
||||
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
|
||||
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
|
||||
],
|
||||
key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
|
||||
# reset_to_zero=['DQ']
|
||||
)
|
||||
@triton.jit
|
||||
def _bwd_kernel(
|
||||
Q, K, V,
|
||||
DO, DQ, DK, DV,
|
||||
LSE, D,
|
||||
softmax_scale,
|
||||
stride_qb, stride_qh, stride_qm,
|
||||
stride_kb, stride_kh, stride_kn,
|
||||
stride_vb, stride_vh, stride_vn,
|
||||
stride_dob, stride_doh, stride_dom,
|
||||
stride_dqb, stride_dqh, stride_dqm,
|
||||
stride_dkb, stride_dkh, stride_dkn,
|
||||
stride_dvb, stride_dvh, stride_dvn,
|
||||
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
|
||||
CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_HEADDIM: tl.constexpr,
|
||||
SEQUENCE_PARALLEL: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
|
||||
):
|
||||
off_hb = tl.program_id(1)
|
||||
off_b = off_hb // nheads
|
||||
off_h = off_hb % nheads
|
||||
# offset pointers for batch/head
|
||||
Q += off_b * stride_qb + off_h * stride_qh
|
||||
K += off_b * stride_kb + off_h * stride_kh
|
||||
V += off_b * stride_vb + off_h * stride_vh
|
||||
DO += off_b * stride_dob + off_h * stride_doh
|
||||
DQ += off_b * stride_dqb + off_h * stride_dqh
|
||||
DK += off_b * stride_dkb + off_h * stride_dkh
|
||||
DV += off_b * stride_dvb + off_h * stride_dvh
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D += off_hb * seqlen_q_rounded
|
||||
LSE += off_hb * seqlen_q_rounded
|
||||
if not SEQUENCE_PARALLEL:
|
||||
num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
|
||||
for start_n in range(0, num_block_n):
|
||||
_bwd_kernel_one_col_block(
|
||||
start_n,
|
||||
Q, K, V, softmax_scale,
|
||||
DO, DQ, DK, DV,
|
||||
LSE, D,
|
||||
stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
|
||||
seqlen_q, seqlen_k,
|
||||
ATOMIC_ADD=False,
|
||||
IS_CAUSAL=IS_CAUSAL,
|
||||
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
||||
)
|
||||
else:
|
||||
start_n = tl.program_id(0)
|
||||
_bwd_kernel_one_col_block(
|
||||
start_n,
|
||||
Q, K, V, softmax_scale,
|
||||
DO, DQ, DK, DV,
|
||||
LSE, D,
|
||||
stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
|
||||
seqlen_q, seqlen_k,
|
||||
ATOMIC_ADD=True,
|
||||
IS_CAUSAL=IS_CAUSAL,
|
||||
BLOCK_HEADDIM=BLOCK_HEADDIM,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
||||
)
|
||||
|
||||
|
||||
def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
|
||||
# shape constraints
|
||||
batch, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, _, _ = k.shape
|
||||
assert k.shape == (batch, seqlen_k, nheads, d)
|
||||
assert v.shape == (batch, seqlen_k, nheads, d)
|
||||
assert d in {16, 32, 64, 128}
|
||||
assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
|
||||
assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
|
||||
assert q.is_cuda and k.is_cuda and v.is_cuda
|
||||
softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
|
||||
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
||||
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
||||
# lse = torch.full((batch, nheads, seqlen_q_rounded), float('inf'), device=q.device,
|
||||
# dtype=torch.float32)
|
||||
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
||||
o = torch.empty_like(q)
|
||||
|
||||
# BLOCK = 128
|
||||
# num_warps = 4 if d <= 64 else 8
|
||||
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, o,
|
||||
lse, tmp,
|
||||
softmax_scale,
|
||||
q.stride(0), q.stride(2), q.stride(1),
|
||||
k.stride(0), k.stride(2), k.stride(1),
|
||||
v.stride(0), v.stride(2), v.stride(1),
|
||||
o.stride(0), o.stride(2), o.stride(1),
|
||||
nheads, seqlen_q, seqlen_k,
|
||||
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
|
||||
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
|
||||
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
|
||||
causal, d,
|
||||
# BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
# num_warps=num_warps,
|
||||
# num_stages=1,
|
||||
)
|
||||
return o, lse, softmax_scale # softmax_scale could have been updated
|
||||
|
||||
|
||||
def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_scale=None):
|
||||
# Make sure that the last dimension is contiguous
|
||||
if do.stride(-1) != 1:
|
||||
do = do.contiguous()
|
||||
batch, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, _, _ = k.shape
|
||||
assert seqlen_q % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
|
||||
assert seqlen_k % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128'
|
||||
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
||||
assert lse.shape == (batch, nheads, seqlen_q_rounded)
|
||||
# dq_accum = torch.zeros_like(q, dtype=torch.float32)
|
||||
dq_accum = torch.empty_like(q, dtype=torch.float32)
|
||||
delta = torch.empty_like(lse)
|
||||
# delta = torch.zeros_like(lse)
|
||||
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
||||
_bwd_preprocess_do_o_dot[grid](
|
||||
o, do, delta,
|
||||
o.stride(0), o.stride(2), o.stride(1),
|
||||
do.stride(0), do.stride(2), do.stride(1),
|
||||
nheads, seqlen_q, seqlen_q_rounded,
|
||||
BLOCK_M=128, BLOCK_HEADDIM=d,
|
||||
)
|
||||
|
||||
# TODO: There are 2 Memcpy DtoD when I use the autotuner.
|
||||
# BLOCK_M = 128
|
||||
# BLOCK_N = 64
|
||||
# num_warps = 4
|
||||
grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
|
||||
batch * nheads)
|
||||
_bwd_kernel[grid](
|
||||
q, k, v,
|
||||
do, dq_accum, dk, dv,
|
||||
lse, delta,
|
||||
softmax_scale,
|
||||
q.stride(0), q.stride(2), q.stride(1),
|
||||
k.stride(0), k.stride(2), k.stride(1),
|
||||
v.stride(0), v.stride(2), v.stride(1),
|
||||
do.stride(0), do.stride(2), do.stride(1),
|
||||
dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
|
||||
dk.stride(0), dk.stride(2), dk.stride(1),
|
||||
dv.stride(0), dv.stride(2), dv.stride(1),
|
||||
nheads, seqlen_q, seqlen_k, seqlen_q_rounded,
|
||||
seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations)
|
||||
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
|
||||
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
|
||||
causal, d,
|
||||
# SEQUENCE_PARALLEL=False,
|
||||
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
||||
# num_warps=num_warps,
|
||||
# num_stages=1,
|
||||
)
|
||||
dq.copy_(dq_accum)
|
||||
|
||||
|
||||
class FlashAttnQKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, qkv, causal=False, softmax_scale=None):
|
||||
"""
|
||||
qkv: (batch, seqlen, 3, nheads, headdim)
|
||||
"""
|
||||
# Make sure that the last dimension is contiguous
|
||||
if qkv.stride(-1) != 1:
|
||||
qkv = qkv.contiguous()
|
||||
o, lse, ctx.softmax_scale = _flash_attn_forward(
|
||||
qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], causal=causal, softmax_scale=softmax_scale
|
||||
)
|
||||
ctx.save_for_backward(qkv, o, lse)
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
qkv, o, lse = ctx.saved_tensors
|
||||
dqkv = torch.empty_like(qkv)
|
||||
_flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
|
||||
dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
|
||||
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
|
||||
return dqkv, None, None
|
||||
|
||||
|
||||
flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
|
||||
|
||||
|
||||
class FlashAttnKVPackedFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, kv, causal=False, softmax_scale=None):
|
||||
"""
|
||||
q: (batch, seqlen, nheads, headdim)
|
||||
kv: (batch, seqlen, 2, nheads, headdim)
|
||||
"""
|
||||
# Make sure that the last dimension is contiguous
|
||||
q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
|
||||
o, lse, ctx.softmax_scale = _flash_attn_forward(
|
||||
q, kv[:, :, 0], kv[:, :, 1], causal=causal, softmax_scale=softmax_scale
|
||||
)
|
||||
ctx.save_for_backward(q, kv, o, lse)
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, kv, o, lse = ctx.saved_tensors
|
||||
dq = torch.empty_like(q)
|
||||
dkv = torch.empty_like(kv)
|
||||
_flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse,
|
||||
dq, dkv[:, :, 0], dkv[:, :, 1],
|
||||
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
|
||||
return dq, dkv, None, None
|
||||
|
||||
|
||||
flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
|
||||
|
||||
|
||||
class FlashAttnFunc(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal=False, softmax_scale=None):
|
||||
"""
|
||||
q, k, v: (batch_size, seqlen, nheads, headdim)
|
||||
"""
|
||||
# Make sure that the last dimension is contiguous
|
||||
q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
|
||||
o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, causal=causal,
|
||||
softmax_scale=softmax_scale)
|
||||
ctx.save_for_backward(q, k, v, o, lse)
|
||||
ctx.causal = causal
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
q, k, v, o, lse = ctx.saved_tensors
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
_flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
|
||||
causal=ctx.causal, softmax_scale=ctx.softmax_scale)
|
||||
return dq, dk, dv, None, None
|
||||
|
||||
|
||||
flash_attn_func = FlashAttnFunc.apply
|
||||
@ -1,6 +1,6 @@
|
||||
# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
|
||||
# for benchmarking.
|
||||
# Fixing some dtype casting to make it work for bfloat16
|
||||
# We fixed a few dtype cast to make it work for bf16
|
||||
|
||||
"""
|
||||
Fused Attention
|
||||
@ -78,7 +78,7 @@ def _fwd_kernel(
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
p = p.to(q.dtype)
|
||||
p = p.to(v.dtype)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
@ -178,7 +178,7 @@ def _bwd_kernel(
|
||||
p = tl.exp(qk * sm_scale - m[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(p.to(q.dtype), do, trans_a=True)
|
||||
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
@ -189,7 +189,7 @@ def _bwd_kernel(
|
||||
dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
|
||||
# # compute dq
|
||||
dq = tl.load(dq_ptrs, eviction_policy="evict_last")
|
||||
dq += tl.dot(ds.to(q.dtype), k)
|
||||
dq += tl.dot(ds.to(k.dtype), k)
|
||||
tl.store(dq_ptrs, dq, eviction_policy="evict_last")
|
||||
# # increment pointers
|
||||
dq_ptrs += BLOCK_M * stride_qm
|
||||
@ -270,95 +270,7 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None
|
||||
return dq.to(q.dtype), dk, dv, None
|
||||
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
|
||||
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
|
||||
sm_scale = 0.3
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
|
||||
for z in range(Z):
|
||||
for h in range(H):
|
||||
p[:, :, M == 0] = float("-inf")
|
||||
p = torch.softmax(p.float(), dim=-1).half()
|
||||
ref_out = torch.matmul(p, v)
|
||||
ref_out.backward(dout)
|
||||
ref_dv, v.grad = v.grad.clone(), None
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# triton implementation
|
||||
tri_out = attention(q, k, v, sm_scale)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
triton.testing.assert_almost_equal(ref_out, tri_out)
|
||||
triton.testing.assert_almost_equal(ref_dv, tri_dv)
|
||||
triton.testing.assert_almost_equal(ref_dk, tri_dk)
|
||||
triton.testing.assert_almost_equal(ref_dq, tri_dq)
|
||||
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
HAS_FLASH = True
|
||||
except BaseException:
|
||||
HAS_FLASH = False
|
||||
|
||||
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 16)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
styles=[('red', '-'), ('blue', '-')],
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['bwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ['fwd', 'bwd']
|
||||
warmup = 25
|
||||
rep = 100
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, sm_scale)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
if provider == "flash":
|
||||
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
|
||||
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
|
||||
cu_seqlens[1:] = lengths.cumsum(0)
|
||||
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
|
||||
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
fn = lambda: o.backward(do, retain_graph=True)
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# only works on A100 at the moment
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
@ -160,6 +160,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo
|
||||
# output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
|
||||
if dropout_mask is not None:
|
||||
attention_drop = attention.masked_fill(~dropout_mask, 0.0)
|
||||
else:
|
||||
attention_drop = attention
|
||||
output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling)
|
||||
if query_padding_mask is not None:
|
||||
output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0)
|
||||
@ -849,3 +851,56 @@ def test_flash_attn_multigpu():
|
||||
assert 0.99 <= dropout_fraction / dropout_p <= 1.01
|
||||
|
||||
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()
|
||||
|
||||
|
||||
from flash_attn.flash_attn_triton import flash_attn_func
|
||||
|
||||
@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
|
||||
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||
@pytest.mark.parametrize('causal', [False, True])
|
||||
# @pytest.mark.parametrize('causal', [True])
|
||||
@pytest.mark.parametrize('d', [64, 128])
|
||||
# @pytest.mark.parametrize('d', [64])
|
||||
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
|
||||
@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)])
|
||||
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)])
|
||||
def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype):
|
||||
if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
|
||||
pytest.skip() # Reference implementation OOM
|
||||
device = 'cuda'
|
||||
# set seed
|
||||
torch.random.manual_seed(0)
|
||||
batch_size = 8
|
||||
nheads = 4
|
||||
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype)
|
||||
k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2)
|
||||
|
||||
q, k, v = [x.detach().requires_grad_() for x in [q, k, v]]
|
||||
output = flash_attn_func(q, k, v, causal)
|
||||
|
||||
output_ref, attn_ref = attention_ref(q, k, v, causal=causal)
|
||||
output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True)
|
||||
print(f'Output max diff: {(output - output_ref).abs().max().item()}')
|
||||
print(f'Output mean diff: {(output - output_ref).abs().mean().item()}')
|
||||
print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}')
|
||||
print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}')
|
||||
|
||||
g = torch.randn_like(output)
|
||||
dq, dk, dv = torch.autograd.grad(output, (q, k, v), g)
|
||||
dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g)
|
||||
dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g)
|
||||
print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}')
|
||||
print(f'dK max diff: {(dk - dk_ref).abs().max().item()}')
|
||||
print(f'dV max diff: {(dv - dv_ref).abs().max().item()}')
|
||||
print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}')
|
||||
print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}')
|
||||
print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}')
|
||||
|
||||
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||
# of a Pytorch implementation.
|
||||
assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
|
||||
# assert torch.allclose(output, output_ref, rtol=rtol, atol=atol)
|
||||
|
||||
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
|
||||
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
|
||||
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user