[Triton] Avoid einops repeat by using Tensor.expand
commit 6b5f271c6d
parent 88c4e5dbf6
@@ -38,8 +38,6 @@ import math
 import torch
-from einops import rearrange, repeat
 import triton
 import triton.language as tl
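For context on the import removal: einops `repeat` materializes a broadcast copy, while `Tensor.expand` returns a zero-copy view, and the two produce equal values for the patterns used in this file. A minimal sketch (the shapes are illustrative, not taken from the diff):

```python
import torch
from einops import repeat  # only needed here to demonstrate the equivalence

# Illustrative shape: (1, nheads, seqlen_q, seqlen_k)
bias = torch.randn(1, 8, 128, 128)

copied = repeat(bias, '1 h ... -> b h ...', b=4)  # allocates batch copies
viewed = bias.expand(4, -1, -1, -1)               # broadcast view: no new memory

assert torch.equal(copied, viewed)
assert viewed.stride(0) == 0  # the broadcast dimension has stride 0
```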
@@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
     else:
         raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                            ' or (seqlen_q, seqlen_k)')
-    if bias.shape[:2] == (1, nheads):
-        bias = repeat(bias, '1 h ... -> b h ...', b=batch)
-    elif bias.shape[:2] == (batch, 1):
-        bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
-    assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
+    bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
     bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
 
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
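One way to see why the single `expand` call can subsume the removed `if`/`elif`/`assert`: `expand` broadcasts any singleton dimension, so both the (1, nheads) and (batch, 1) layouts are handled, and non-broadcastable shapes raise a RuntimeError on their own. A hedged sketch with made-up sizes:

```python
import torch

batch, nheads, seqlen_q, seqlen_k = 2, 8, 128, 128
full = (batch, nheads, seqlen_q, seqlen_k)

# Case covered by the removed `if` branch: singleton batch dim.
per_head = torch.randn(1, nheads, seqlen_q, seqlen_k)
assert per_head.expand(*full).shape == full

# Case covered by the removed `elif` branch: singleton head dim.
per_batch = torch.randn(batch, 1, seqlen_q, seqlen_k)
assert per_batch.expand(*full).shape == full

# Non-broadcastable shapes fail loudly, covering the removed assert.
bad = torch.randn(3, nheads, seqlen_q, seqlen_k)
try:
    bad.expand(*full)
except RuntimeError as err:
    print(err)  # expanded size must match the existing non-singleton size
```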
@@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
     else:
         raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                            ' or (seqlen_q, seqlen_k)')
-    if bias.shape[:2] == (1, nheads):
-        bias = repeat(bias, '1 h ... -> b h ...', b=batch)
-    elif bias.shape[:2] == (batch, 1):
-        bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
-    assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
+    bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
     bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
 
     # BLOCK_M = 128
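The backward pass gets the identical change. The practical payoff sits in the `bias_strides` line kept just below the edit: an expanded view reports stride 0 along its broadcast dimensions, so the Triton kernel indexes the original bias memory directly rather than a repeated copy. A small sketch of the strides involved (sizes illustrative):

```python
import torch

batch, nheads, seqlen_q, seqlen_k = 4, 8, 128, 128

bias = torch.randn(1, nheads, seqlen_q, seqlen_k)
bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)

# Mirrors the bias_strides line in the diff: the batch stride is 0, so every
# batch index addresses the same underlying bias memory inside the kernel.
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2))
print(bias_strides)  # (0, 16384, 128)
```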