From 13403e81157ba37ca525890f2f0f2137edf75311 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Sun, 11 Sep 2022 12:09:43 -0700
Subject: [PATCH] Relax assert to allow both bf16 and fp16

---
 flash_attn/flash_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash_attn/flash_attention.py b/flash_attn/flash_attention.py
index 2b70ea8..0e110a3 100644
--- a/flash_attn/flash_attention.py
+++ b/flash_attn/flash_attention.py
@@ -34,7 +34,7 @@ class FlashAttention(nn.Module):
             key_padding_mask: a bool tensor of shape (B, S)
         """
         assert not need_weights
-        assert qkv.dtype == torch.float16
+        assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
 
         if cu_seqlens is None:
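
For context, a minimal usage sketch of what the relaxed assert permits: FlashAttention now accepts bfloat16 as well as float16 qkv tensors. This sketch is not part of the patch; the constructor arguments, the (B, S, 3, H, D) qkv packing, and the (output, attn_weights) return shape are assumptions about this version of the module.

# Minimal sketch (not part of the patch): exercising the relaxed dtype assert.
# Assumes qkv is packed as (batch, seqlen, 3, nheads, headdim) and that the
# module takes softmax_scale / attention_dropout in its constructor.
import torch
from flash_attn.flash_attention import FlashAttention

attn = FlashAttention(softmax_scale=None, attention_dropout=0.0)

B, S, H, D = 2, 128, 8, 64
# Before this commit, bfloat16 here failed the `qkv.dtype == torch.float16`
# assert; now both float16 and bfloat16 pass.
qkv = torch.randn(B, S, 3, H, D, device="cuda", dtype=torch.bfloat16)

out, _ = attn(qkv, causal=True)   # assumed to return (output, attn_weights=None)
print(out.shape, out.dtype)       # torch.Size([2, 128, 8, 64]) torch.bfloat16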