[Triton] Avoid einops repeat by using Tensor.expand

This commit is contained in:
Tri Dao 2022-12-14 14:48:41 -08:00
parent 88c4e5dbf6
commit 6b5f271c6d

View File

@ -38,8 +38,6 @@ import math
import torch
from einops import rearrange, repeat
import triton
import triton.language as tl
@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
if bias.shape[:2] == (1, nheads):
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
elif bias.shape[:2] == (batch, 1):
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
else:
raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
' or (seqlen_q, seqlen_k)')
if bias.shape[:2] == (1, nheads):
bias = repeat(bias, '1 h ... -> b h ...', b=batch)
elif bias.shape[:2] == (batch, 1):
bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
# BLOCK_M = 128