diff --git a/benchmarks/benchmark_causal.py b/benchmarks/benchmark_causal.py index a1969ab..b3713e0 100644 --- a/benchmarks/benchmark_causal.py +++ b/benchmarks/benchmark_causal.py @@ -6,9 +6,11 @@ import torch.nn.functional as F from einops import rearrange, repeat -from flash_attn.utils.benchmark import benchmark_all, pytorch_profiler +from flash_attn.utils.benchmark import benchmark_forward, benchmark_all, pytorch_profiler from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func -from flash_attn.triton.fused_attention import attention as attention +# from flash_attn.triton.fused_attention import attention as attention +from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func +from flash_attn.flash_attn_triton_og import attention as attention_og try: from flash_attn.fused_softmax import scaled_upper_triang_masked_softmax @@ -45,19 +47,6 @@ def attention_pytorch(qkv, dropout_p=0.0, causal=True): return output.to(dtype=qkv.dtype) -def attention_triton(q, k, v): - """ - No dropout and only support causal=True. - Triton implementation seems to require q, k, v being contiguous? - Arguments: - q, k, v: (batch_size, nheads, seqlen, head_dim) - Output: - output: (batch_size, nheads, seqlen, head_dim) - """ - softmax_scale = 1.0 / math.sqrt(q.shape[-1]) - return attention(q, k, v, softmax_scale) - - def attention_megatron(qkv): """ Arguments: @@ -85,6 +74,10 @@ batch_size = 2 seqlen = 4096 nheads = 12 headdim = 128 +# batch_size = 64 +# seqlen = 512 +# nheads = 8 +# headdim = 128 dropout_p = 0.0 causal = True dtype = torch.bfloat16 @@ -100,9 +93,13 @@ benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, desc='PyTorch Attention') +benchmark_all(flash_attn_qkvpacked_func, qkv, causal, repeats=repeats, desc='FlashAttention Triton') +pytorch_profiler(flash_attn_qkvpacked_func, qkv, causal, backward=True) + q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)] -benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton') +benchmark_all(attention_og, q, k, v, 1.0, repeats=repeats, desc='FlashAttention Triton OG') +# pytorch_profiler(attention, q, k, v, 1.0, backward=True) if scaled_upper_triang_masked_softmax is not None: benchmark_all(attention_megatron, qkv, repeats=repeats, desc='Megatron Attention') diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py new file mode 100644 index 0000000..489f56d --- /dev/null +++ b/flash_attn/flash_attn_triton.py @@ -0,0 +1,529 @@ +""" +Based on the FlashAttention implementation from Phil Tillet. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Support both causal and non-causal attention. +- Speed up the forward pass a bit (and only store the LSE instead of m and l). +- Make the backward for d=128 much faster by reducing register spilling. +- Add the option to parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. +""" + +import math + +import torch + +from einops import rearrange + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % (args["BLOCK_N"]) == 0, + } +) +@triton.jit +def _fwd_kernel( + Q, K, V, Out, + Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_ob, stride_oh, stride_om, + nheads, seqlen_q, seqlen_k, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + if EVEN_M: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + if not EVEN_N: + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + # Slightly faster to multiply the softmax_scale here since the compiler can then + # fuse the mult and add into an fma instruction. + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :]) + if EVEN_M: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + + +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + } +) +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, DO, Delta, + stride_ob, stride_oh, stride_om, + stride_dob, stride_doh, stride_dom, + nheads, seqlen_q, seqlen_q_rounded, + EVEN_M: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + if EVEN_M: + o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]).to(tl.float32) + do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :]).to(tl.float32) + else: + o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32) + do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=offs_m[:, None] < seqlen_q, other=0.0).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, K, V, softmax_scale, + DO, DQ, DK, DV, + LSE, D, + stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, + ATOMIC_ADD: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_k[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_k[None, :]) + # initialize dv amd dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + lse_i = tl.load(LSE + offs_m_curr) + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + dp = tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not ATOMIC_ADD: + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + tl.atomic_add(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_k[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_k[None, :]) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +def init_to_zero(name): + # def fn(nargs): + # with torch.no_grad(): + # nargs[name].zero_() + # return fn + return lambda nargs: nargs[name].zero_() + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'], + # reset_to_zero=['DQ'] +) +@triton.jit +def _bwd_kernel( + Q, K, V, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_dob, stride_doh, stride_dom, + stride_dqb, stride_dqh, stride_dqm, + stride_dkb, stride_dkh, stride_dkn, + stride_dvb, stride_dvh, stride_dvn, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, K, V, softmax_scale, + DO, DQ, DK, DV, + LSE, D, + stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, + ATOMIC_ADD=False, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, K, V, softmax_scale, + DO, DQ, DK, DV, + LSE, D, + stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, + ATOMIC_ADD=True, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + + +def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d in {16, 32, 64, 128} + assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' + assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + # lse = torch.full((batch, nheads, seqlen_q_rounded), float('inf'), device=q.device, + # dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + # BLOCK = 128 + # num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, k, v, o, + lse, tmp, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + o.stride(0), o.stride(2), o.stride(1), + nheads, seqlen_q, seqlen_k, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + causal, d, + # BLOCK_M=BLOCK, BLOCK_N=BLOCK, + # num_warps=num_warps, + # num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_scale=None): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert seqlen_q % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128' + assert seqlen_k % 128 == 0, 'Backward pass currently only support seqlen that are multiples of 128' + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, do, delta, + o.stride(0), o.stride(2), o.stride(1), + do.stride(0), do.stride(2), do.stride(1), + nheads, seqlen_q, seqlen_q_rounded, + BLOCK_M=128, BLOCK_HEADDIM=d, + ) + + # TODO: There are 2 Memcpy DtoD when I use the autotuner. + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads) + _bwd_kernel[grid]( + q, k, v, + do, dq_accum, dk, dv, + lse, delta, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + do.stride(0), do.stride(2), do.stride(1), + dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), + dk.stride(0), dk.stride(2), dk.stride(1), + dv.stride(0), dv.stride(2), dv.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + causal, d, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, causal=False, softmax_scale=None): + """ + qkv: (batch, seqlen, 3, nheads, headdim) + """ + # Make sure that the last dimension is contiguous + if qkv.stride(-1) != 1: + qkv = qkv.contiguous() + o, lse, ctx.softmax_scale = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(qkv, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + qkv, o, lse = ctx.saved_tensors + dqkv = torch.empty_like(qkv) + _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, + dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dqkv, None, None + + +flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, causal=False, softmax_scale=None): + """ + q: (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + """ + # Make sure that the last dimension is contiguous + q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, kv, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, kv, o, lse = ctx.saved_tensors + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + _flash_attn_backward(do, q, qkv[:, :, 0], qkv[:, :, 1], o, lse, + dq, dkv[:, :, 0], dkv[:, :, 1], + causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dkv, None, None + + +flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal=False, softmax_scale=None): + """ + q, k, v: (batch_size, seqlen, nheads, headdim) + """ + # Make sure that the last dimension is contiguous + q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, causal=causal, + softmax_scale=softmax_scale) + ctx.save_for_backward(q, k, v, o, lse) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse = ctx.saved_tensors + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dk, dv, None, None + + +flash_attn_func = FlashAttnFunc.apply diff --git a/flash_attn/triton/fused_attention.py b/flash_attn/flash_attn_triton_og.py similarity index 72% rename from flash_attn/triton/fused_attention.py rename to flash_attn/flash_attn_triton_og.py index 48387bc..fb165d5 100644 --- a/flash_attn/triton/fused_attention.py +++ b/flash_attn/flash_attn_triton_og.py @@ -1,6 +1,6 @@ # [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py # for benchmarking. -# Fixing some dtype casting to make it work for bfloat16 +# We fixed a few dtype cast to make it work for bf16 """ Fused Attention @@ -78,7 +78,7 @@ def _fwd_kernel( acc = acc * acc_scale[:, None] # update acc v = tl.load(v_ptrs + start_n * stride_vk) - p = p.to(q.dtype) + p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new @@ -178,7 +178,7 @@ def _bwd_kernel( p = tl.exp(qk * sm_scale - m[:, None]) # compute dv do = tl.load(do_ptrs) - dv += tl.dot(p.to(q.dtype), do, trans_a=True) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) # compute dp = dot(v, do) Di = tl.load(D_ptrs + offs_m_curr) dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] @@ -189,7 +189,7 @@ def _bwd_kernel( dk += tl.dot(ds.to(q.dtype), q, trans_a=True) # # compute dq dq = tl.load(dq_ptrs, eviction_policy="evict_last") - dq += tl.dot(ds.to(q.dtype), k) + dq += tl.dot(ds.to(k.dtype), k) tl.store(dq_ptrs, dq, eviction_policy="evict_last") # # increment pointers dq_ptrs += BLOCK_M * stride_qm @@ -270,95 +270,7 @@ class _attention(torch.autograd.Function): BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, num_stages=1, ) - return dq, dk, dv, None + return dq.to(q.dtype), dk, dv, None attention = _attention.apply - - -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)]) -def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): - torch.manual_seed(20) - q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_() - sm_scale = 0.3 - dout = torch.randn_like(q) - # reference implementation - M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) - p = torch.matmul(q, k.transpose(2, 3)) * sm_scale - for z in range(Z): - for h in range(H): - p[:, :, M == 0] = float("-inf") - p = torch.softmax(p.float(), dim=-1).half() - ref_out = torch.matmul(p, v) - ref_out.backward(dout) - ref_dv, v.grad = v.grad.clone(), None - ref_dk, k.grad = k.grad.clone(), None - ref_dq, q.grad = q.grad.clone(), None - # triton implementation - tri_out = attention(q, k, v, sm_scale) - tri_out.backward(dout) - tri_dv, v.grad = v.grad.clone(), None - tri_dk, k.grad = k.grad.clone(), None - tri_dq, q.grad = q.grad.clone(), None - # compare - triton.testing.assert_almost_equal(ref_out, tri_out) - triton.testing.assert_almost_equal(ref_dv, tri_dv) - triton.testing.assert_almost_equal(ref_dk, tri_dk) - triton.testing.assert_almost_equal(ref_dq, tri_dq) - - -try: - from flash_attn.flash_attn_interface import flash_attn_func - HAS_FLASH = True -except BaseException: - HAS_FLASH = False - -BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 -# vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 16)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} -) for mode in ['bwd']] - - -@triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"): - assert mode in ['fwd', 'bwd'] - warmup = 25 - rep = 100 - if provider == "triton": - q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True) - sm_scale = 1.3 - fn = lambda: attention(q, k, v, sm_scale) - if mode == 'bwd': - o = fn() - do = torch.randn_like(o) - fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) - return ms - if provider == "flash": - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) - cu_seqlens[1:] = lengths.cumsum(0) - qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) - fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) - if mode == 'bwd': - o = fn() - do = torch.randn_like(o) - fn = lambda: o.backward(do, retain_graph=True) - ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep) - return ms - -# only works on A100 at the moment -# bench_flash_attention.run(save_path='.', print_data=True) diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index d093975..ae32668 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -160,6 +160,8 @@ def attention_ref(q, k, v, query_padding_mask=None, key_padding_mask=None, dropo # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) if dropout_mask is not None: attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention output = torch.einsum('bhts,bshd->bthd', attention_drop, v * dropout_scaling) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, 'b s -> b s 1 1'), 0.0) @@ -849,3 +851,56 @@ def test_flash_attn_multigpu(): assert 0.99 <= dropout_fraction / dropout_p <= 1.01 assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() + + +from flash_attn.flash_attn_triton import flash_attn_func + +@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) +# @pytest.mark.parametrize('dtype', [torch.bfloat16]) +@pytest.mark.parametrize('causal', [False, True]) +# @pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('d', [64, 128]) +# @pytest.mark.parametrize('d', [64]) +# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]) +@pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 512), (512, 256), (1024, 1024), (2048, 2048)]) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(512, 256)]) +def test_flash_attn_triton(seqlen_q, seqlen_k, d, causal, dtype): + if seqlen_q >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30: + pytest.skip() # Reference implementation OOM + device = 'cuda' + # set seed + torch.random.manual_seed(0) + batch_size = 8 + nheads = 4 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) + k, v = torch.randn(batch_size, seqlen_k, 2, nheads, d, device=device, dtype=dtype).unbind(dim=2) + + q, k, v = [x.detach().requires_grad_() for x in [q, k, v]] + output = flash_attn_func(q, k, v, causal) + + output_ref, attn_ref = attention_ref(q, k, v, causal=causal) + output_pt, attn_pt = attention_ref(q, k, v, causal=causal, upcast=False, reorder_ops=True) + print(f'Output max diff: {(output - output_ref).abs().max().item()}') + print(f'Output mean diff: {(output - output_ref).abs().mean().item()}') + print(f'Pytorch max diff: {(output_pt - output_ref).abs().max().item()}') + print(f'Pytorch mean diff: {(output_pt - output_ref).abs().mean().item()}') + + g = torch.randn_like(output) + dq, dk, dv = torch.autograd.grad(output, (q, k, v), g) + dq_ref, dk_ref, dv_ref, = torch.autograd.grad(output_ref, (q, k, v), g) + dq_pt, dk_pt, dv_pt, = torch.autograd.grad(output_pt, (q, k, v), g) + print(f'dQ max diff: {(dq - dq_ref).abs().max().item()}') + print(f'dK max diff: {(dk - dk_ref).abs().max().item()}') + print(f'dV max diff: {(dv - dv_ref).abs().max().item()}') + print(f'dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}') + print(f'dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}') + print(f'dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}') + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item() + # assert torch.allclose(output, output_ref, rtol=rtol, atol=atol) + + assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()