Fix: check the type of max_seqlen_k instead of checking max_seqlen twice (#1127)

This commit is contained in:
SueJane 2024-08-05 23:59:23 +08:00 committed by GitHub
parent 3f6ff1c1c5
commit 3f1b4d38e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -196,7 +196,7 @@ class FlashCrossAttention(nn.Module):
assert cu_seqlens_k is not None assert cu_seqlens_k is not None
assert cu_seqlens_k.dtype == torch.int32 assert cu_seqlens_k.dtype == torch.int32
assert max_seqlen_k is not None assert max_seqlen_k is not None
assert isinstance(max_seqlen, int) assert isinstance(max_seqlen_k, int)
return flash_attn_varlen_kvpacked_func( return flash_attn_varlen_kvpacked_func(
q, q,
kv, kv,