diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py index 0e110a3..39bbfa6 100644 --- a/flash_attn/flash_attention.py +++ b/flash_attn/flash_attention.py @@ -4,7 +4,6 @@ import torch.nn as nn from einops import rearrange -from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis @@ -75,7 +74,7 @@ class FlashAttention(nn.Module): class FlashMHA(nn.Module): def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0, - causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None: + causal=False, device=None, dtype=None, **kwargs) -> None: assert batch_first factory_kwargs = {'device': device, 'dtype': dtype} super().__init__() @@ -85,14 +84,7 @@ class FlashMHA(nn.Module): self.num_heads = num_heads assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" self.head_dim = self.embed_dim // num_heads - assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128" - - assert use_rotary_emb in [None, '1d', '2d'] - self.use_rotary_emb = use_rotary_emb - if self.use_rotary_emb == '1d': - self.rotary_emb = RotaryEmbedding(self.head_dim) - elif self.use_rotary_emb == '2d': - self.rotary_emb = RotaryEmbedding2D(self.head_dim) + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs) @@ -103,13 +95,7 @@ class FlashMHA(nn.Module): key_padding_mask: bool tensor of shape (batch, seqlen) """ qkv = self.Wqkv(x) - if self.use_rotary_emb: - query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, - h=self.num_heads).unbind(dim=2) - query, key = self.rotary_emb(query, key, seq_dimension=-3) - qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2) - else: - qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal) return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights