Allow headdim 128 in FlashMHA interface

This commit is contained in:
Tri Dao 2022-08-05 09:46:08 -07:00
parent 2ed471ecc4
commit 713ea302d7

View File

@ -24,20 +24,16 @@ class FlashAttention(nn.Module):
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
key_padding_mask: An implementation of BaseMask that encodes how
many query each sequence in the batch consists of
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert attn_mask is None
assert qkv.dtype == torch.float16
assert qkv.is_cuda
@ -55,10 +51,9 @@ class FlashAttention(nn.Module):
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
key_padding_mask_bool = key_padding_mask.bool_matrix
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
@ -90,7 +85,7 @@ class FlashMHA(nn.Module):
self.num_heads = num_heads
assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
self.head_dim = self.embed_dim // num_heads
assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
assert use_rotary_emb in [None, '1d', '2d']
self.use_rotary_emb = use_rotary_emb
@ -103,8 +98,10 @@ class FlashMHA(nn.Module):
self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
need_weights=False):
def forward(self, x, key_padding_mask=None):
"""x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
key_padding_mask: bool tensor of shape (batch, seqlen)
"""
qkv = self.Wqkv(x)
if self.use_rotary_emb:
query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,