Allow headdim 128 in FlashMHA interface
commit 713ea302d7
parent 2ed471ecc4
@@ -24,20 +24,16 @@ class FlashAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                 max_s=None, need_weights=False):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                 if unpadded: (nnz, 3, h, d)
-            attn_mask: An implementation of BaseMask that encodes where each
-                query can attend to
-            key_padding_mask: An implementation of BaseMask that encodes how
-                many query each sequence in the batch consists of
+            key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert attn_mask is None
         assert qkv.dtype == torch.float16
         assert qkv.is_cuda
 
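After this hunk, FlashAttention.forward no longer takes an attn_mask or BaseMask objects: padding is described directly by a bool key_padding_mask. A minimal sketch of a call under the new signature (the import path, the sizes, and the return-value handling here are assumptions for illustration, not part of this commit):

import torch
from flash_attn.flash_attention import FlashAttention  # assumed import path for the module in this diff

batch_size, seqlen, nheads, headdim = 2, 1024, 16, 64
# QKV packed as (B, S, 3, H, D); the asserts above require fp16 on CUDA
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim,
                  dtype=torch.float16, device='cuda')
# True marks a real token, False marks padding
key_padding_mask = torch.ones(batch_size, seqlen, dtype=torch.bool, device='cuda')
key_padding_mask[:, seqlen // 2:] = False  # pretend the second half of each sequence is padding

attn = FlashAttention(attention_dropout=0.1)
out = attn(qkv, key_padding_mask=key_padding_mask, causal=False)
# depending on the version, forward may return (output, None) instead of just output; unpack accordingly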
@@ -55,10 +51,9 @@ class FlashAttention(nn.Module):
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             else:
-                key_padding_mask_bool = key_padding_mask.bool_matrix
                 nheads = qkv.shape[-2]
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
-                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                 output_unpad = flash_attn_unpadded_qkvpacked_func(
                     x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
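The padded branch above now feeds the bool mask straight into unpad_input instead of first extracting key_padding_mask.bool_matrix from a mask object. For intuition, an illustrative re-implementation of the quantities unpad_input produces (a sketch of the idea only, not the actual helper from the flash-attn padding utilities):

import torch
import torch.nn.functional as F

def unpad_input_sketch(x, key_padding_mask):
    # x: (batch, seqlen, hidden); key_padding_mask: bool (batch, seqlen), True = keep
    seqlens = key_padding_mask.sum(dim=-1, dtype=torch.int32)            # valid tokens per sequence
    indices = torch.nonzero(key_padding_mask.flatten(), as_tuple=False).flatten()
    max_s = int(seqlens.max())
    # cumulative sequence lengths with a leading zero, as the unpadded kernels expect
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    x_unpad = x.flatten(0, 1)[indices]                                   # (nnz, hidden)
    return x_unpad, indices, cu_seqlens, max_s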
@@ -90,7 +85,7 @@ class FlashMHA(nn.Module):
         self.num_heads = num_heads
         assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
-        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
+        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
 
         assert use_rotary_emb in [None, '1d', '2d']
         self.use_rotary_emb = use_rotary_emb
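The only change in this hunk relaxes the head-dim check so that FlashMHA accepts head_dim == 128. As a sketch, a configuration that previously tripped the assertion and is now allowed (the specific sizes and the device/dtype keywords are illustrative; they are forwarded through factory_kwargs as shown elsewhere in this file):

import torch

# head_dim = embed_dim // num_heads = 2048 // 16 = 128
mha = FlashMHA(embed_dim=2048, num_heads=16, device='cuda', dtype=torch.float16)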
@@ -103,8 +98,10 @@ class FlashMHA(nn.Module):
         self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
-    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
-                need_weights=False):
+    def forward(self, x, key_padding_mask=None):
+        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+        key_padding_mask: bool tensor of shape (batch, seqlen)
+        """
         qkv = self.Wqkv(x)
         if self.use_rotary_emb:
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
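FlashMHA.forward drops the ignored positional arguments, attn_mask, and need_weights, leaving only the input and the bool key_padding_mask documented in the new docstring. Continuing the construction sketch above (sizes are illustrative, and the exact return convention is not shown in this diff):

x = torch.randn(2, 1024, 2048, dtype=torch.float16, device='cuda')      # (batch, seqlen, hidden_dim)
key_padding_mask = torch.ones(2, 1024, dtype=torch.bool, device='cuda')
key_padding_mask[1, 768:] = False            # trailing positions of the second sequence are padding

# old call sites that passed x_ignored_, attn_mask=..., or need_weights=... must switch to:
out = mha(x, key_padding_mask=key_padding_mask)
# depending on the version, this may return (output, attn_weights); unpack accordingly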