fix: cast the alibi slopes to torch.float32 (#846)
parent 4a73e903da
commit 6bbc532388
@@ -101,6 +101,8 @@ class FlashSelfAttention(nn.Module):
         assert qkv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
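The kernel-side alibi_slopes argument is expected in float32 even when the activations run in fp16/bf16, which is why forward() now casts self.alibi_slopes before dispatching. A minimal sketch of the pattern, assuming a power-of-two head count and using a hypothetical build_alibi_slopes helper (not part of flash-attn's API):

import torch

def build_alibi_slopes(nheads: int) -> torch.Tensor:
    # Hypothetical helper: the standard ALiBi schedule 2**(-8*i/nheads)
    # for i = 1..nheads (exact for power-of-two head counts).
    start = 2.0 ** (-8.0 / nheads)
    return torch.tensor([start ** (i + 1) for i in range(nheads)])

# Slopes built in the model's compute dtype (e.g. fp16) ...
slopes = build_alibi_slopes(16).to(torch.float16)
# ... are cast to float32 before reaching the kernel, mirroring this commit.
slopes = slopes.to(torch.float32)
print(slopes.dtype)  # torch.float32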
@@ -185,6 +187,8 @@ class FlashCrossAttention(nn.Module):
         assert q.is_cuda and kv.is_cuda
         causal = self.causal if causal is None else causal
         unpadded = cu_seqlens is not None
+        if self.alibi_slopes is not None:
+            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
         if unpadded:
             assert cu_seqlens.dtype == torch.int32
             assert max_seqlen is not None
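The cross-attention forward gets the identical guard, so slopes stored in half precision work with both modules. A small self-contained sketch of that guard, written as a hypothetical ensure_fp32_slopes helper rather than the module code itself:

import torch

def ensure_fp32_slopes(slopes):
    # Hypothetical helper mirroring the guard added to both forward() methods:
    # leave None alone, otherwise hand back a float32 copy of the tensor.
    if slopes is None:
        return None
    return slopes if slopes.dtype == torch.float32 else slopes.to(torch.float32)

half_slopes = torch.full((16,), 0.0625, dtype=torch.float16)
assert ensure_fp32_slopes(half_slopes).dtype == torch.float32
assert ensure_fp32_slopes(None) is None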