fix: cast the alibi slopes to torch.float32 (#846)

Markus Krimmel authored 2024-03-15 08:49:40 +01:00, committed by GitHub
parent 4a73e903da
commit 6bbc532388

@@ -101,6 +101,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
@@ -185,6 +187,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
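
For context, a minimal sketch of how ALiBi slopes are typically built and handed to these modules. The get_alibi_slopes helper below is illustrative, not code from this repository; the only assumption drawn from the diff is that the FlashAttention kernels require the slopes in torch.float32, so a module-wide dtype conversion (e.g. model.half() or model.to(torch.bfloat16)) that also downcasts a stored alibi_slopes buffer would otherwise hand the kernel the wrong dtype, which is what the cast in forward guards against.

import math

import torch


def get_alibi_slopes(nheads: int) -> torch.Tensor:
    # Illustrative helper (not part of this commit): standard ALiBi slopes,
    # i.e. the geometric sequence 2**(-8/n), 2**(-16/n), ... for n heads.
    def power_of_2_slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(nheads).is_integer():
        slopes = power_of_2_slopes(nheads)
    else:
        # For non-power-of-two head counts, take every other slope from the
        # next power of two, as in the ALiBi reference implementation.
        closest = 2 ** math.floor(math.log2(nheads))
        slopes = (
            power_of_2_slopes(closest)
            + power_of_2_slopes(2 * closest)[0::2][: nheads - closest]
        )
    # The FlashAttention kernels expect float32 slopes; the cast added in this
    # commit restores that dtype if a module-wide conversion downcast them.
    return torch.tensor(slopes, dtype=torch.float32)

Assuming the attention modules accept alibi_slopes at construction (as the self.alibi_slopes attribute in the diff suggests), something like FlashSelfAttention(causal=True, alibi_slopes=get_alibi_slopes(nheads)) keeps working even after the surrounding model is cast to half precision, because forward now re-casts the slopes to float32 before they reach the kernel.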