Fix: check the type of max_seqlen_k instead of checking max_seqlen twice (#1127)
This commit is contained in:
parent
3f6ff1c1c5
commit
3f1b4d38e7
@ -196,7 +196,7 @@ class FlashCrossAttention(nn.Module):
|
|||||||
assert cu_seqlens_k is not None
|
assert cu_seqlens_k is not None
|
||||||
assert cu_seqlens_k.dtype == torch.int32
|
assert cu_seqlens_k.dtype == torch.int32
|
||||||
assert max_seqlen_k is not None
|
assert max_seqlen_k is not None
|
||||||
assert isinstance(max_seqlen, int)
|
assert isinstance(max_seqlen_k, int)
|
||||||
return flash_attn_varlen_kvpacked_func(
|
return flash_attn_varlen_kvpacked_func(
|
||||||
q,
|
q,
|
||||||
kv,
|
kv,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user