""" Based on the FlashAttention implementation from Phil Tillet. https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py Changes: - Implement both causal and non-causal attention. - Implement cross-attention (not just self-attention). - Support arbitrary seqlens (not just multiples of 128), for both forward and backward. - [WIP] Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both the forward pass and backward pass. For the backward pass, head dims that are not 16, 32, 64, 128 will require more testing since there seems to be some race conditions due to the Triton compiler. - Speed up the forward pass a bit, and only store the LSE instead of m and l. - Make the backward for d=128 much faster by reducing register spilling. - Optionally parallelize the backward pass across seqlen_k, to deal with the case of small batch size * nheads. """ import math import torch import triton import triton.language as tl @triton.autotune( configs=[ triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=8, num_stages=1), triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), ], key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'] ) @triton.heuristics( { "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], } ) @triton.jit def _fwd_kernel( Q, K, V, Out, Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug softmax_scale, stride_qb, stride_qh, stride_qm, stride_kb, stride_kh, stride_kn, stride_vb, stride_vh, stride_vn, stride_ob, stride_oh, stride_om, nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, IS_CAUSAL: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): start_m = tl.program_id(0) off_hb = tl.program_id(1) off_b = off_hb // nheads off_h = off_hb % nheads # off_b = tl.program_id(1) # off_h = tl.program_id(2) # off_hb = off_b * nheads + off_h # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_HEADDIM) # Initialize pointers to Q, K, V # Adding parenthesis around indexing might use int32 math instead of int64 math? # https://github.com/openai/triton/issues/741 # I'm seeing a tiny bit of difference (5-7us) q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) # initialize pointer to m and l t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # load q: it will stay in SRAM throughout # [2022-10-30] TD: Idk why but in the case of EVEN_M=True and EVEN_N=False, if we just call # tl.load(q_ptrs), we get the wrong output! Could be a bug in the compiler? 
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over k, v and update accumulator
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
        # Slightly faster to multiply the softmax_scale here since the compiler can then
        # fuse the mult and add into an fma instruction.
        p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # scale acc_o
        acc_o_scale = tl.exp(m_i - m_ij)
        # -- update output accumulator --
        # BUG: have to store and immediately load
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]
        # update acc_o
        if EVEN_N:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)
        # -- update statistics
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_n[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs, acc_o,
                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))


@triton.jit
def _bwd_preprocess_do_o_dot(
    Out, DO, Delta,
    stride_ob, stride_oh, stride_om,
    stride_dob, stride_doh, stride_dom,
    nheads, seqlen_q, seqlen_q_rounded, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # load
    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)


@triton.jit
def _bwd_kernel_one_col_block(
    start_n,
    Q, K, V,
    softmax_scale,
    DO, DQ, DK, DV,
    LSE, D,
    stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
    seqlen_q, seqlen_k, headdim,
    ATOMIC_ADD: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
    begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
    # initialize row/col offsets
    offs_qm = begin_m + tl.arange(0, BLOCK_M)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_m = tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # initialize pointers to value-like data
    q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
    do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
    dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
    # initialize dv and dk
    dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
    # k and v stay in SRAM throughout
    # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
    # if we just call tl.load(k_ptrs), we get the wrong output!
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs)
            v = tl.load(v_ptrs)
        else:
            k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
            v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
        else:
            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                        other=0.0)
    # loop over rows
    num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
    for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
        start_m = tl.multiple_of(start_m, BLOCK_M)
        offs_m_curr = start_m + offs_m
        # load q, k, v, do on-chip
        if EVEN_M:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs)
            else:
                q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                            other=0.0)
        # recompute p = softmax(qk, dim=-1).T
        qk = tl.dot(q, k, trans_b=True)
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
        if IS_CAUSAL:
            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
        # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        lse_i = tl.load(LSE + offs_m_curr)
        p = tl.exp(qk * softmax_scale - lse_i[:, None])
        # compute dv
        # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
        # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
        # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
        # the output is correct.
        if EVEN_M & EVEN_HEADDIM:
            do = tl.load(do_ptrs)
        # if EVEN_M:
        #     if EVEN_HEADDIM:
        #         do = tl.load(do_ptrs)
        #     else:
        #         do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
            else:
                do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
                                           & (offs_d[None, :] < headdim), other=0.0)
        dv += tl.dot(p.to(do.dtype), do, trans_a=True)
        # compute dp = dot(v, do)
        # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
        if not EVEN_HEADDIM:
            tl.debug_barrier()
        dp = tl.dot(do, v, trans_b=True)
        # compute ds = p * (dp - delta[:, None])
        # Putting the subtraction after the dp matmul (instead of before) is slightly faster
        Di = tl.load(D + offs_m_curr)
        # Converting ds to q.dtype here reduces register pressure and makes it much faster
        # for BLOCK_HEADDIM=128
        ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
        # compute dk = dot(ds.T, q)
        dk += tl.dot(ds, q, trans_a=True)
        # compute dq
        if not ATOMIC_ADD:
            if EVEN_M:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs, mask=offs_d[None, :] < headdim, other=0.0,
                                 eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_d[None, :] < headdim,
                             eviction_policy="evict_last")
            else:
                if EVEN_HEADDIM:
                    dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0,
                                 eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q,
                             eviction_policy="evict_last")
                else:
                    dq = tl.load(dq_ptrs,
                                 mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                                 other=0.0, eviction_policy="evict_last")
                    dq += tl.dot(ds, k)
                    tl.store(dq_ptrs, dq,
                             mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                             eviction_policy="evict_last")
        else:  # If we're parallelizing across the seqlen_k dimension
            dq = tl.dot(ds, k)
            if EVEN_M:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq)
                else:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_d[None, :] < headdim)
            else:
                if EVEN_HEADDIM:
                    tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
                else:
                    tl.atomic_add(dq_ptrs, dq,
                                  mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
        # increment pointers
        dq_ptrs += BLOCK_M * stride_dqm
        q_ptrs += BLOCK_M * stride_qm
        do_ptrs += BLOCK_M * stride_dom
    # write-back
    dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
    dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
    if EVEN_N:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))


def init_to_zero(name):
    # def fn(nargs):
    #     with torch.no_grad():
    #         nargs[name].zero_()
    # return fn
    return lambda nargs: nargs[name].zero_()


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        # Kernel is buggy (gives wrong result) if we set BLOCK_M=128, BLOCK_N=64, num_warps=*4*
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
                      num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
                      num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
                      num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1),
        # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1),
    ],
    key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'IS_CAUSAL', 'BLOCK_HEADDIM'],
    # reset_to_zero=['DQ']
)
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _bwd_kernel(
    Q, K, V,
    DO, DQ, DK, DV,
    LSE, D,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_dob, stride_doh, stride_dom,
    stride_dqb, stride_dqh, stride_dqm,
    stride_dkb, stride_dkh, stride_dkn,
    stride_dvb, stride_dvh, stride_dvn,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    # offset pointers for batch/head
    Q += off_b * stride_qb + off_h * stride_qh
    K += off_b * stride_kb + off_h * stride_kh
    V += off_b * stride_vb + off_h * stride_vh
    DO += off_b * stride_dob + off_h * stride_doh
    DQ += off_b * stride_dqb + off_h * stride_dqh
    DK += off_b * stride_dkb + off_h * stride_dkh
    DV += off_b * stride_dvb + off_h * stride_dvh
    # pointer to row-wise quantities in value-like data
    D += off_hb * seqlen_q_rounded
    LSE += off_hb * seqlen_q_rounded
    if not SEQUENCE_PARALLEL:
        num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                start_n,
                Q, K, V,
                softmax_scale,
                DO, DQ, DK, DV,
                LSE, D,
                stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
                seqlen_q, seqlen_k, headdim,
                ATOMIC_ADD=False,
                IS_CAUSAL=IS_CAUSAL,
                BLOCK_HEADDIM=BLOCK_HEADDIM,
                EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
                BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
            )
    else:
        start_n = tl.program_id(0)
        _bwd_kernel_one_col_block(
            start_n,
            Q, K, V,
            softmax_scale,
            DO, DQ, DK, DV,
            LSE, D,
            stride_qm, stride_kn, stride_vn, stride_dom, stride_dqm, stride_dkn, stride_dvn,
            seqlen_q, seqlen_k, headdim,
            ATOMIC_ADD=True,
            IS_CAUSAL=IS_CAUSAL,
            BLOCK_HEADDIM=BLOCK_HEADDIM,
            EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
        )


def _flash_attn_forward(q, k, v, causal=False, softmax_scale=None):
    # shape constraints
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, 'FlashAttention only supports head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(d)

    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    # lse = torch.full((batch, nheads, seqlen_q_rounded), float('inf'), device=q.device,
    #                  dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    # BLOCK = 128
    # num_warps = 4 if d <= 64 else 8
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _fwd_kernel[grid](
        q, k, v, o,
        lse, tmp,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        o.stride(0), o.stride(2), o.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, BLOCK_HEADDIM,
        # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        # num_warps=num_warps,
        # num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated


def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, causal=False, softmax_scale=None):
    # Make sure that the last dimension is contiguous
    if do.stride(-1) != 1:
        do = do.contiguous()
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    # assert d in {16, 32, 64, 128}
    assert d <= 128
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    assert lse.shape == (batch, nheads, seqlen_q_rounded)
    # dq_accum = torch.zeros_like(q, dtype=torch.float32)
    dq_accum = torch.empty_like(q, dtype=torch.float32)
    delta = torch.empty_like(lse)
    # delta = torch.zeros_like(lse)
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    _bwd_preprocess_do_o_dot[grid](
        o, do, delta,
        o.stride(0), o.stride(2), o.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        nheads, seqlen_q, seqlen_q_rounded, d,
        BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM,
    )

    # TODO: There are 2 Memcpy DtoD when I use the autotuner.
    # BLOCK_M = 128
    # BLOCK_N = 64
    # num_warps = 4
    grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
                         batch * nheads)
    _bwd_kernel[grid](
        q, k, v,
        do, dq_accum, dk, dv,
        lse, delta,
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),
        k.stride(0), k.stride(2), k.stride(1),
        v.stride(0), v.stride(2), v.stride(1),
        do.stride(0), do.stride(2), do.stride(1),
        dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1),
        dk.stride(0), dk.stride(2), dk.stride(1),
        dv.stride(0), dv.stride(2), dv.stride(1),
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        causal, BLOCK_HEADDIM,
        # SEQUENCE_PARALLEL=False,
        # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        # num_warps=num_warps,
        # num_stages=1,
    )
    dq.copy_(dq_accum)


class FlashAttnQKVPackedFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, qkv, causal=False, softmax_scale=None):
        """
            qkv: (batch, seqlen, 3, nheads, headdim)
        """
        # Make sure that the last dimension is contiguous
        if qkv.stride(-1) != 1:
            qkv = qkv.contiguous()
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(qkv, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        qkv, o, lse = ctx.saved_tensors
        dqkv = torch.empty_like(qkv)
        _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse,
                             dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dqkv, None, None


flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply


class FlashAttnKVPackedFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, kv, causal=False, softmax_scale=None):
        """
            q: (batch, seqlen_q, nheads, headdim)
            kv: (batch, seqlen_k, 2, nheads, headdim)
        """
        # Make sure that the last dimension is contiguous
        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(
            q, kv[:, :, 0], kv[:, :, 1], causal=causal, softmax_scale=softmax_scale
        )
        ctx.save_for_backward(q, kv, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, kv, o, lse = ctx.saved_tensors
        dq = torch.empty_like(q)
        dkv = torch.empty_like(kv)
        _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse,
                             dq, dkv[:, :, 0], dkv[:, :, 1],
                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dkv, None, None


flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply


class FlashAttnFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, causal=False, softmax_scale=None):
        """
            q, k, v: (batch_size, seqlen, nheads, headdim)
        """
        # Make sure that the last dimension is contiguous
        q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
        o, lse, ctx.softmax_scale = _flash_attn_forward(q, k, v, causal=causal,
                                                        softmax_scale=softmax_scale)
        ctx.save_for_backward(q, k, v, o, lse)
        ctx.causal = causal
        return o

    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse = ctx.saved_tensors
        dq = torch.empty_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
                             causal=ctx.causal, softmax_scale=ctx.softmax_scale)
        return dq, dk, dv, None, None


flash_attn_func = FlashAttnFunc.apply
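

# The self-test below is not part of the original kernel code; it is a minimal sketch of how
# flash_attn_func could be checked against a plain PyTorch attention reference on a CUDA device.
# The helper `attention_ref`, the shapes, and the tolerances implied by the printed errors are
# illustrative assumptions, not a definitive test suite.
def attention_ref(q, k, v, causal=False, softmax_scale=None):
    # Reference attention in fp32: softmax(scale * q @ k^T) @ v, with an optional causal mask
    # that lets query i attend to keys j <= i (matching the kernel's offs_m >= offs_n mask
    # when seqlen_q == seqlen_k).
    q, k, v = [x.float() for x in (q, k, v)]
    softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k) * softmax_scale
    if causal:
        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
        causal_mask = torch.triu(
            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), diagonal=1
        )
        scores = scores.masked_fill(causal_mask, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum('bhqk,bkhd->bqhd', attn, v)


if __name__ == '__main__':
    torch.manual_seed(0)
    # Odd seqlen exercises the uneven-block (EVEN_M/EVEN_N = False) code paths.
    batch, seqlen, nheads, d = 2, 257, 4, 64
    q, k, v = [torch.randn(batch, seqlen, nheads, d, device='cuda', dtype=torch.float16,
                           requires_grad=True) for _ in range(3)]
    # autograd.Function.apply takes positional args: (q, k, v, causal, softmax_scale)
    out = flash_attn_func(q, k, v, True)
    out_ref = attention_ref(q, k, v, causal=True)
    print('fwd max err:', (out.float() - out_ref).abs().max().item())

    # Backward check: run the Triton backward first, stash the grads, then compare against
    # the grads of the fp32 reference.
    g = torch.randn_like(out)
    out.backward(g)
    dq, dk, dv = q.grad.clone(), k.grad.clone(), v.grad.clone()
    q.grad, k.grad, v.grad = None, None, None
    out_ref.backward(g.float())
    print('dq max err:', (dq.float() - q.grad.float()).abs().max().item())
    print('dk max err:', (dk.float() - k.grad.float()).abs().max().item())
    print('dv max err:', (dv.float() - v.grad.float()).abs().max().item())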