From b410d14f28c419fba626ada20781f3677931a63d Mon Sep 17 00:00:00 2001
From: eric-tc-wong <85085599+eric-tc-wong@users.noreply.github.com>
Date: Tue, 6 Sep 2022 17:29:49 -0400
Subject: [PATCH] Update flash_attention.py

Recasting query and key after rotary_emb()
---
 flash_attn/flash_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py
index 040e623..2b70ea8 100644
--- a/flash_attn/flash_attention.py
+++ b/flash_attn/flash_attention.py
@@ -107,7 +107,7 @@ class FlashMHA(nn.Module):
             query, key, value = rearrange(qkv, 'b s (three h d) -> b s three h d',
                                           three=3, h=self.num_heads).unbind(dim=2)
             query, key = self.rotary_emb(query, key, seq_dimension=-3)
-            qkv = torch.stack([query, key, value], dim=2)
+            qkv = torch.stack([query.type(x.dtype), key.type(x.dtype), value], dim=2)
         else:
             qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
         context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
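
Why the cast is needed (a minimal sketch, not part of the patch): some rotary-embedding implementations compute the cos/sin rotation in float32 for accuracy and hand back upcast query/key tensors. Stacking those with the still-fp16 value then yields a float32 qkv via type promotion (or a dtype error on older PyTorch versions), which the fp16/bf16-only FlashAttention kernel rejects. The sketch below uses a hypothetical apply_rotary() as a stand-in for self.rotary_emb(); it is not the library's API.

import torch

def apply_rotary(q, k):
    # Hypothetical stand-in: computes in float32 and returns float32
    # tensors, mimicking a rotary embedding that upcasts for accuracy.
    cos = torch.ones(q.shape[-1], dtype=torch.float32)
    return q.float() * cos, k.float() * cos

x_dtype = torch.float16                          # dtype of the input x
q = torch.randn(2, 16, 8, 64, dtype=x_dtype)     # (batch, seqlen, heads, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)

q, k = apply_rotary(q, k)                        # q and k are now float32

# Without the cast, type promotion makes the stacked qkv float32.
qkv_bad = torch.stack([q, k, v], dim=2)
print(qkv_bad.dtype)                             # torch.float32

# The patch's fix: recast query and key to x's dtype before stacking.
qkv = torch.stack([q.type(x_dtype), k.type(x_dtype), v], dim=2)
assert qkv.dtype == torch.float16                # uniform dtype for the kernel

Casting query and key back to x.dtype, rather than letting value promote to float32, keeps qkv in the half-precision layout that inner_attn expects.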