From 3f1b4d38e7c9ba56a333f2a6e2afe65b844c8da5 Mon Sep 17 00:00:00 2001
From: SueJane <80026679+SueJane@users.noreply.github.com>
Date: Mon, 5 Aug 2024 23:59:23 +0800
Subject: [PATCH] Fix: check the type of max_seqlen_k instead of checking
 max_seqlen twice (#1127)

---
 flash_attn/modules/mha.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 1ded4c1..77640c2 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -196,7 +196,7 @@ class FlashCrossAttention(nn.Module):
             assert cu_seqlens_k is not None
             assert cu_seqlens_k.dtype == torch.int32
             assert max_seqlen_k is not None
-            assert isinstance(max_seqlen, int)
+            assert isinstance(max_seqlen_k, int)
             return flash_attn_varlen_kvpacked_func(
                 q,
                 kv,
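
Note: as the subject line says, the query-side max_seqlen is already type-checked
a few lines above this hunk, so the second isinstance(max_seqlen, int) was a
duplicate and max_seqlen_k was never type-checked at all. The sketch below shows
the varlen ("unpadded") call path these assertions guard. It is a minimal
example, assuming a CUDA GPU with flash-attn installed; the batch layout,
sequence lengths, and dtype are illustrative, not taken from the patch.

    import torch
    from flash_attn.modules.mha import FlashCrossAttention

    nheads, headdim = 8, 64
    seqlens_q = [3, 5]   # per-sequence query lengths
    seqlens_k = [4, 7]   # per-sequence key/value lengths

    # Varlen layout: sequences are concatenated along dim 0 and delimited by
    # int32 cumulative-length tensors (what the dtype asserts check).
    cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, 4, 11], dtype=torch.int32, device="cuda")

    q = torch.randn(sum(seqlens_q), nheads, headdim,
                    dtype=torch.float16, device="cuda")
    kv = torch.randn(sum(seqlens_k), 2, nheads, headdim,
                     dtype=torch.float16, device="cuda")

    attn = FlashCrossAttention()
    out = attn(
        q,
        kv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max(seqlens_q),    # plain Python int, checked earlier in forward()
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_k=max(seqlens_k),  # plain Python int, the value this patch now checks
    )
    # out has shape (total_q, nheads, headdim) == (8, 8, 64)

Before this fix, passing something like a 0-d tensor as max_seqlen_k could slip
past the duplicated assert and only fail later, deeper in the kernel dispatch;
with the corrected assert it is rejected up front with a clear error.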