Remove RotaryEmbedding from FlashAttention module

To avoid import error if one doesn't have rotary_emb installed
This commit is contained in:
Tri Dao 2022-11-10 11:54:36 -08:00
parent 6998e0ecdb
commit 55797f32c9

View File

@ -4,7 +4,6 @@ import torch.nn as nn
from einops import rearrange
from flash_attn.rotary import RotaryEmbedding, RotaryEmbedding2D
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
@ -75,7 +74,7 @@ class FlashAttention(nn.Module):
class FlashMHA(nn.Module):
def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
causal=False, device=None, dtype=None, **kwargs) -> None:
assert batch_first
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
@ -85,14 +84,7 @@ class FlashMHA(nn.Module):
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
assert use_rotary_emb in [None, '1d', '2d']
self.use_rotary_emb = use_rotary_emb
if self.use_rotary_emb == '1d':
self.rotary_emb = RotaryEmbedding(self.head_dim)
elif self.use_rotary_emb == '2d':
self.rotary_emb = RotaryEmbedding2D(self.head_dim)
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
@ -103,13 +95,7 @@ class FlashMHA(nn.Module):
key_padding_mask: bool tensor of shape (batch, seqlen)
"""
qkv = self.Wqkv(x)
if self.use_rotary_emb:
query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
h=self.num_heads).unbind(dim=2)
query, key = self.rotary_emb(query, key, seq_dimension=-3)
qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
else:
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
need_weights=need_weights, causal=self.causal)
return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights