Relax assert to allow both bf16 and fp16

Tri Dao 2022-09-11 12:09:43 -07:00
parent 64f42cd057
commit 13403e8115


@@ -34,7 +34,7 @@ class FlashAttention(nn.Module):
             key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert qkv.dtype == torch.float16
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
         if cu_seqlens is None:
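
For context, a minimal usage sketch of what the relaxed assert now permits. The import path flash_attn.flash_attention.FlashAttention, the no-argument constructor, and the packed qkv layout (batch, seqlen, 3, nheads, headdim) are assumptions about the repo's API at the time of this commit, not taken from the diff itself.

    import torch
    from flash_attn.flash_attention import FlashAttention  # assumed import path

    attn = FlashAttention()  # assumed default constructor

    batch, seqlen, nheads, headdim = 2, 128, 8, 64
    # Before this commit only float16 passed the dtype assert; bfloat16 is now accepted too.
    qkv = torch.randn(batch, seqlen, 3, nheads, headdim,
                      dtype=torch.bfloat16, device="cuda")

    # need_weights defaults to False, matching the `assert not need_weights` above,
    # so the second return value is None.
    out, attn_weights = attn(qkv)
    print(out.shape, out.dtype)  # torch.Size([2, 128, 8, 64]) torch.bfloat16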