Relax assert to allow both bf16 and fp16
parent 64f42cd057
commit 13403e8115
@@ -34,7 +34,7 @@ class FlashAttention(nn.Module):
             key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert qkv.dtype == torch.float16
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
 
         if cu_seqlens is None:
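For context, a minimal sketch of what the relaxed check now accepts; the packed (B, S, 3, H, D) qkv layout, the sizes, and the helper name below are illustrative assumptions, not part of this diff.

import torch

# Mirrors the relaxed assertion: fp16 and bf16 packed qkv tensors now pass,
# while any other dtype (e.g. fp32) still trips the assert.
def check_qkv(qkv: torch.Tensor) -> None:
    assert qkv.dtype in [torch.float16, torch.bfloat16]
    assert qkv.is_cuda

if torch.cuda.is_available():
    B, S, H, D = 2, 128, 8, 64  # batch, seqlen, heads, head dim (illustrative)
    for dtype in (torch.float16, torch.bfloat16):
        qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=dtype)
        check_qkv(qkv)  # both dtypes are accepted after this change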