From 713ea302d7b3a191d200f9cb7fd174a3872a9b92 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Fri, 5 Aug 2022 09:46:08 -0700
Subject: [PATCH] Allow headdim 128 in FlashMHA interface

---
 flash_attn/flash_attention.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py
index 27d30dd..a3693ee 100644
--- a/flash_attn/flash_attention.py
+++ b/flash_attn/flash_attention.py
@@ -24,20 +24,16 @@ class FlashAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
+    def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
                 max_s=None, need_weights=False):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
                 if unpadded: (nnz, 3, h, d)
-            attn_mask: An implementation of BaseMask that encodes where each
-                query can attend to
-            key_padding_mask: An implementation of BaseMask that encodes how
-                many query each sequence in the batch consists of
+            key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert attn_mask is None
         assert qkv.dtype == torch.float16
         assert qkv.is_cuda
 
@@ -55,10 +51,9 @@ class FlashAttention(nn.Module):
                 )
                 output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
             else:
-                key_padding_mask_bool = key_padding_mask.bool_matrix
                 nheads = qkv.shape[-2]
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
-                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
+                x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
                 output_unpad = flash_attn_unpadded_qkvpacked_func(
                     x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
@@ -90,7 +85,7 @@ class FlashMHA(nn.Module):
         self.num_heads = num_heads
         assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
-        assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
+        assert self.head_dim in [16, 32, 64, 128], "Only support head_dim == 16, 32, 64, or 128"
         assert use_rotary_emb in [None, '1d', '2d']
         self.use_rotary_emb = use_rotary_emb
 
@@ -103,8 +98,10 @@ class FlashMHA(nn.Module):
         self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
 
-    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
-                need_weights=False):
+    def forward(self, x, key_padding_mask=None):
+        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+        key_padding_mask: bool tensor of shape (batch, seqlen)
+        """
         qkv = self.Wqkv(x)
         if self.use_rotary_emb:
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3,
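
Usage sketch for the updated interface (illustrative, not part of the patch): it assumes FlashMHA can be constructed as FlashMHA(embed_dim, num_heads) as in this file, and that key_padding_mask follows the unpad_input convention (True marks tokens to keep). head_dim = 2048 // 16 = 128 exercises the newly allowed head dimension, and the input is fp16 on CUDA to satisfy the asserts in FlashAttention.forward.

import torch
from flash_attn.flash_attention import FlashMHA

# Constructor args assumed to be (embed_dim, num_heads); adjust to the actual signature.
batch, seqlen, embed_dim, num_heads = 2, 1024, 2048, 16   # head_dim = 2048 // 16 = 128
mha = FlashMHA(embed_dim, num_heads).to(device='cuda', dtype=torch.float16)

# FlashAttention.forward asserts fp16 inputs on a CUDA device.
x = torch.randn(batch, seqlen, embed_dim, device='cuda', dtype=torch.float16)

# key_padding_mask is now a plain bool tensor of shape (batch, seqlen);
# here the second sequence only has 700 valid tokens, the rest is padding.
lengths = torch.tensor([1024, 700], device='cuda')
key_padding_mask = torch.arange(seqlen, device='cuda')[None, :] < lengths[:, None]

# Padded positions are stripped by unpad_input inside FlashAttention and re-padded on output.
out = mha(x, key_padding_mask=key_padding_mask)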