diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py
index 2b70ea8..0e110a3 100644
--- a/flash_attn/flash_attention.py
+++ b/flash_attn/flash_attention.py
@@ -34,7 +34,7 @@ class FlashAttention(nn.Module):
             key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert qkv.dtype == torch.float16
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
         if cu_seqlens is None:
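
With this change the dtype assertion accepts bfloat16 as well as float16. Below is a minimal sketch of how a caller could exercise the relaxed check, assuming `FlashAttention` can be constructed with default arguments and that its `forward` takes a packed `qkv` tensor of shape `(B, S, 3, H, D)` when `cu_seqlens` is `None` and returns `(output, attn_weights)`; the shapes and constructor arguments here are illustrative assumptions, not part of the diff.

```python
import torch
from flash_attn.flash_attention import FlashAttention

# Sketch only: default construction and the (B, S, 3, H, D) packed-qkv call
# are assumptions based on the surrounding module, not shown in this diff.
flash_attn = FlashAttention().cuda()

B, S, H, D = 2, 1024, 8, 64
# Packed query/key/value; bfloat16 now passes the dtype assert as well as float16.
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.bfloat16)

out, _ = flash_attn(qkv, need_weights=False)
```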