diff --git a/flash_attn/flash_attn_triton.py b/flash_attn/flash_attn_triton.py
index a284d8e..ebc1bf8 100644
--- a/flash_attn/flash_attn_triton.py
+++ b/flash_attn/flash_attn_triton.py
@@ -38,8 +38,6 @@
 import math
 
 import torch
-from einops import rearrange, repeat
-
 import triton
 import triton.language as tl
 
@@ -605,11 +603,7 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         else:
             raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                                ' or (seqlen_q, seqlen_k)')
-        if bias.shape[:2] == (1, nheads):
-            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
-        elif bias.shape[:2] == (batch, 1):
-            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
-        assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
+        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
     bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
 
     seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
@@ -684,11 +678,7 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
         else:
             raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                                ' or (seqlen_q, seqlen_k)')
-        if bias.shape[:2] == (1, nheads):
-            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
-        elif bias.shape[:2] == (batch, 1):
-            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
-        assert bias.shape[:2] == (batch, nheads), 'First 2 dimensions of bias must be broadcastible to (batch, nheads)'
+        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
     bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
 
     # BLOCK_M = 128
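
Note (not part of the patch): a minimal sketch of why the single torch.Tensor.expand call can stand in for the removed einops.repeat branches. expand returns a broadcast view whose singleton dimensions get stride 0, and the kernel only consumes bias.stride(0), bias.stride(1), bias.stride(2), so the batch/head broadcast happens in the kernel's address arithmetic instead of materializing a copy; expand also raises a RuntimeError on non-broadcastable shapes, which replaces the old assert. The shapes below are made up for illustration.

import torch

batch, nheads, seqlen_q, seqlen_k = 2, 3, 4, 5

# A per-head bias that should be broadcast over the batch dimension.
bias = torch.randn(1, nheads, seqlen_q, seqlen_k)

expanded = bias.expand(batch, nheads, seqlen_q, seqlen_k)
print(expanded.shape)                           # torch.Size([2, 3, 4, 5])
print(expanded.stride(0))                       # 0: every batch index aliases the same rows
print(expanded.data_ptr() == bias.data_ptr())   # True: a view, no data copied

# einops.repeat would produce the same values but materialize an extra copy:
#   repeated = einops.repeat(bias, '1 h ... -> b h ...', b=batch)
#   torch.equal(repeated, expanded)  # True, at the cost of batch * bias.numel() memory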