Allow headdim 128 in FlashMHA interface
commit 713ea302d7 (parent 2ed471ecc4)
@@ -24,20 +24,16 @@ class FlashAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                 max_s=None, need_weights=False):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                 if unpadded: (nnz, 3, h, d)
-            attn_mask: An implementation of BaseMask that encodes where each
-                query can attend to
-            key_padding_mask: An implementation of BaseMask that encodes how
-                many query each sequence in the batch consists of
+            key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert attn_mask is None
         assert qkv.dtype == torch.float16
         assert qkv.is_cuda
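Per the updated docstring, the padded path now takes the packed (B, S, 3, H, D) tensor plus a plain bool mask instead of the old BaseMask objects. Below is a minimal sketch of a call under that interface; the concrete sizes, the True-means-valid mask convention, and constructing FlashAttention with only attention_dropout are assumptions for illustration, not guarantees from this hunk.

    import torch

    attn = FlashAttention(attention_dropout=0.0)
    # Packed qkv as (B, S, 3, H, D); the asserts above require fp16 tensors on CUDA.
    qkv = torch.randn(2, 128, 3, 8, 64, device='cuda', dtype=torch.float16)
    # Bool (B, S) key_padding_mask; here every position is treated as valid (assumed semantics).
    key_padding_mask = torch.ones(2, 128, dtype=torch.bool, device='cuda')
    out = attn(qkv, key_padding_mask=key_padding_mask, causal=False)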
@@ -55,10 +51,9 @@ class FlashAttention(nn.Module):
             )
             output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         else:
-            key_padding_mask_bool = key_padding_mask.bool_matrix
             nheads = qkv.shape[-2]
             x = rearrange(qkv, 'b s three h d -> b s (three h d)')
-            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
+            x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
             x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
             output_unpad = flash_attn_unpadded_qkvpacked_func(
                 x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
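In this branch, unpad_input now consumes the bool mask directly rather than a .bool_matrix attribute. The sketch below is not the library's unpad_input; it is a self-contained illustration of the equivalent gather, assuming True marks valid tokens and that cu_seqlens and max_s are the cumulative sequence lengths and longest sequence, as the variable names in the hunk suggest.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    batch, seqlen, hidden = 2, 4, 8
    x = torch.randn(batch, seqlen, hidden)
    seqlens = torch.tensor([4, 2])                                        # per-sequence valid lengths (illustrative)
    key_padding_mask = torch.arange(seqlen)[None, :] < seqlens[:, None]   # bool (B, S)

    # Keep only the non-padded tokens, giving the packed (nnz, hidden) layout.
    flat_mask = rearrange(key_padding_mask, 'b s -> (b s)')
    indices = torch.nonzero(flat_mask, as_tuple=False).flatten()
    x_unpad = rearrange(x, 'b s d -> (b s) d')[indices]                   # nnz = seqlens.sum() rows

    # Boundaries of each sequence inside the packed tensor, plus the longest length.
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0), (1, 0)).to(torch.int32)
    max_s = int(seqlens.max())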
@@ -90,7 +85,7 @@ class FlashMHA(nn.Module):
         self.num_heads = num_heads
         assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
-        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
+        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
 
         assert use_rotary_emb in [None, '1d', '2d']
         self.use_rotary_emb = use_rotary_emb
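This relaxed assertion is the change the commit title describes: a head dimension of 128 is now accepted at construction time. A hedged sketch follows, using only the constructor arguments visible in this diff (embed_dim, num_heads, bias, attention_dropout, and device/dtype presumably routed through factory_kwargs); other parameters and their defaults are not shown in the hunk.

    import torch

    # 1024 // 8 = 128, which the previous assert would have rejected.
    mha = FlashMHA(embed_dim=1024, num_heads=8, bias=True, attention_dropout=0.0,
                   device='cuda', dtype=torch.float16)
    assert mha.head_dim == 128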
@@ -103,8 +98,10 @@ class FlashMHA(nn.Module):
         self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
-    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
-                need_weights=False):
+    def forward(self, x, key_padding_mask=None):
+        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+        key_padding_mask: bool tensor of shape (batch, seqlen)
+        """
         qkv = self.Wqkv(x)
         if self.use_rotary_emb:
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
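With the placeholder arguments and attn_mask gone, calling the module reduces to the hidden-state tensor plus an optional bool mask. A minimal usage sketch under the same assumptions as above (True marks valid positions, fp16 CUDA inputs, default values for the omitted constructor arguments); the hunk does not show what forward returns, so the result is left unpacked.

    import torch

    mha = FlashMHA(embed_dim=1024, num_heads=8, device='cuda', dtype=torch.float16)
    x = torch.randn(2, 256, 1024, device='cuda', dtype=torch.float16)
    key_padding_mask = torch.ones(2, 256, dtype=torch.bool, device='cuda')
    key_padding_mask[1, 200:] = False   # treat the tail of the second sequence as padding (assumed semantics)
    out = mha(x, key_padding_mask=key_padding_mask)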