diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 1ded4c1..77640c2 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -196,7 +196,7 @@ class FlashCrossAttention(nn.Module):
             assert cu_seqlens_k is not None
             assert cu_seqlens_k.dtype == torch.int32
             assert max_seqlen_k is not None
-            assert isinstance(max_seqlen, int)
+            assert isinstance(max_seqlen_k, int)
             return flash_attn_varlen_kvpacked_func(
                 q,
                 kv,
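
The fix corrects a copy-paste slip: the second `isinstance` check re-validated `max_seqlen` instead of `max_seqlen_k`, so a caller passing a non-int `max_seqlen_k` (e.g. a tensor) would sail past the asserts and fail later inside the kernel. Below is a minimal caller-side sketch of the varlen path these asserts guard. It assumes flash-attn is installed and a CUDA device is available; the shapes, dtypes, and argument order follow the public `flash_attn_varlen_kvpacked_func` API, but the concrete sizes are illustrative.

```python
import torch
from flash_attn import flash_attn_varlen_kvpacked_func

nheads, headdim = 8, 64
seqlens_q = [3, 5]   # per-sequence query lengths in the packed batch
seqlens_k = [4, 7]   # per-sequence key/value lengths

# cu_seqlens_* are int32 cumulative sequence lengths, as the asserts require.
cu_seqlens_q = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 4, 11], dtype=torch.int32, device="cuda")

# max_seqlen_k must be a Python int, not a tensor -- the check this diff fixes.
max_seqlen_q = max(seqlens_q)
max_seqlen_k = max(seqlens_k)

# Packed layouts: q is (total_q, nheads, headdim), kv is (total_k, 2, nheads, headdim).
q = torch.randn(sum(seqlens_q), nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(sum(seqlens_k), 2, nheads, headdim, dtype=torch.float16, device="cuda")

out = flash_attn_varlen_kvpacked_func(
    q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
)
```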