diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 89c7680..1ded4c1 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -101,6 +101,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -185,6 +187,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
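The patch coerces `self.alibi_slopes` to `float32` at the top of `forward` in both `FlashSelfAttention` and `FlashCrossAttention`, so slopes created in another dtype (e.g. alongside fp16/bf16 model weights) are cast before the attention kernel consumes them. For context, below is a minimal sketch of how ALiBi slopes are typically built and cast the same way; the helper name `alibi_slopes_for`, the power-of-two head count, and the commented constructor usage are illustrative assumptions, not part of the patch:

```python
import torch

def alibi_slopes_for(nheads: int) -> torch.Tensor:
    # Standard ALiBi slopes: a geometric sequence with ratio 2**(-8/nheads).
    # For brevity this assumes nheads is a power of two.
    ratio = 2.0 ** (-8.0 / nheads)
    return torch.tensor([ratio ** (i + 1) for i in range(nheads)])

slopes = alibi_slopes_for(16)
# The cast the patch applies inside forward(), here done up front:
slopes = slopes.to(torch.float32)
# The slopes would then be handed to the patched modules, e.g. (assumed usage):
#   attn = FlashSelfAttention(causal=True, alibi_slopes=slopes)
print(slopes)
```

Doing the cast inside `forward` rather than once at construction means the slopes stay in `float32` even if a caller replaces `self.alibi_slopes` after the module is built, at the cost of a small per-call check.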