From 6bbc532388e61185a92e2a563126739967b4c8c5 Mon Sep 17 00:00:00 2001
From: Markus Krimmel
Date: Fri, 15 Mar 2024 08:49:40 +0100
Subject: [PATCH] fix: cast the alibi slopes to torch.float32 (#846)

---
 flash_attn/modules/mha.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 89c7680..1ded4c1 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -101,6 +101,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -185,6 +187,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
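
Illustrative note (not part of the patch above): a minimal sketch of the situation the cast addresses, assuming flash-attn is installed, a CUDA device is available, and that FlashSelfAttention accepts an alibi_slopes constructor argument as the patched module implies; the slope values, shapes, and dtypes below are made up for illustration.

# Sketch only, not part of the patch. Shows slopes created in half precision
# being handled by the patched forward(), which casts them to torch.float32
# before the flash-attn kernel call.
import torch
from flash_attn.modules.mha import FlashSelfAttention  # assumed import path, matches the patched file

nheads, headdim, batch, seqlen = 8, 64, 2, 128

# ALiBi slopes are sometimes created in (or cast to) the model's compute dtype;
# the flash-attn kernels, however, expect them as float32.
alibi_slopes = torch.tensor(
    [2 ** (-8 * (i + 1) / nheads) for i in range(nheads)],
    device="cuda",
    dtype=torch.float16,  # deliberately not float32
)

attn = FlashSelfAttention(causal=True, alibi_slopes=alibi_slopes)

# Packed QKV layout used by FlashSelfAttention: (batch, seqlen, 3, nheads, headdim)
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)

# With the patch, forward() casts self.alibi_slopes to torch.float32 before
# invoking the kernel, so half-precision slopes no longer trip the dtype check.
out = attn(qkv)
print(out.shape)  # (batch, seqlen, nheads, headdim)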