Fix: check the type of max_seqlen_k instead of checking max_seqlen twice (#1127)
This commit is contained in:
parent
3f6ff1c1c5
commit
3f1b4d38e7
@ -196,7 +196,7 @@ class FlashCrossAttention(nn.Module):
|
||||
assert cu_seqlens_k is not None
|
||||
assert cu_seqlens_k.dtype == torch.int32
|
||||
assert max_seqlen_k is not None
|
||||
assert isinstance(max_seqlen, int)
|
||||
assert isinstance(max_seqlen_k, int)
|
||||
return flash_attn_varlen_kvpacked_func(
|
||||
q,
|
||||
kv,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user