[Rotary] Pass max_seqlen from mha.py to rotary during inference
parent 942fcbf046
commit de2949f37d
@@ -402,10 +402,10 @@ class RotaryEmbedding(torch.nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        elif max_seqlen is not None:
+        if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if kv is None:
             if self.scale is None:
                 return apply_rotary_emb_qkv_(
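For orientation, below is a minimal standalone sketch of the cache-sizing precedence this hunk introduces. It is not the library code: the class, its fields, and the toy driver are illustrative assumptions; only the ordering of the two checks mirrors the diff above. When max_seqlen is known, the cos/sin cache is sized once for the whole generation; otherwise it falls back to seqlen + seqlen_offset as before.

import torch

class ToyRotaryCache:
    # Illustrative stand-in for RotaryEmbedding's cos/sin cache; names and
    # internals are assumptions, only the check ordering mirrors the hunk above.
    def __init__(self, dim=8, base=10000.0):
        self.dim, self.base = dim, base
        self._seq_len_cached = 0
        self._cos_cached = self._sin_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        if seqlen <= self._seq_len_cached:
            return  # cache already covers this length
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device).float() / self.dim))
        freqs = torch.outer(torch.arange(seqlen, device=device, dtype=torch.float32), inv_freq)
        self._seq_len_cached = seqlen
        self._cos_cached, self._sin_cached = freqs.cos().to(dtype), freqs.sin().to(dtype)

    def prepare(self, qkv, seqlen_offset=0, max_seqlen=None):
        seqlen = qkv.shape[1]
        # New ordering: a known max_seqlen wins, sizing the cache up front.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        return self._cos_cached, self._sin_cached

cache = ToyRotaryCache()
prompt_qkv = torch.randn(1, 16, 3, 4, 8)  # (batch, seqlen, three, heads, headdim)
cache.prepare(prompt_qkv, seqlen_offset=0, max_seqlen=128)
print(cache._seq_len_cached)  # 128: sized for the full generation, not just the prompt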
@@ -606,6 +606,9 @@ class MHA(nn.Module):
                 else {"key_padding_mask": key_padding_mask, **kwargs}
             )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
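To make the new rotary_max_seqlen expression concrete, the snippet below re-derives the two values from a stand-in for inference_params. The dataclass is a hypothetical stub whose field names simply mirror the attributes read in the hunk above; it is not the flash-attn InferenceParams definition.

from dataclasses import dataclass
from typing import Optional

@dataclass
class StubInferenceParams:
    # Hypothetical stub; only the two fields the diff reads are modeled.
    max_sequence_len: int
    sequence_len_offset: int = 0

def rotary_args(inference_params: Optional[StubInferenceParams]):
    # Mirrors the two expressions used in MHA.forward above.
    seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
    rotary_max_seqlen = inference_params.max_sequence_len if inference_params is not None else None
    return seqlen_offset, rotary_max_seqlen

print(rotary_args(None))                                               # (0, None) -> training / no KV cache
print(rotary_args(StubInferenceParams(2048, sequence_len_offset=16)))  # (16, 2048) -> decode step with a 16-token prefix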
@@ -623,7 +626,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -669,7 +674,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
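The two call sites above (the self-attention qkv path and the q, kv path) forward the same pair of keyword arguments. As a rough illustration of what reaches self.rotary_emb(...) over a prompt step followed by single-token decode steps, here is a hedged sketch; the helper, shapes, and numbers are assumptions, not MHA internals.

import torch

def rotary_call_args(step, prompt_len, max_sequence_len):
    # Hypothetical helper: the arguments MHA.forward would hand to rotary_emb at each step.
    if step == 0:
        seqlen, seqlen_offset = prompt_len, 0             # whole prompt, no offset
    else:
        seqlen, seqlen_offset = 1, prompt_len + step - 1  # one new token, growing offset
    qkv = torch.randn(2, seqlen, 3, 16, 64)               # (batch, seqlen, three, heads, headdim)
    # qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=max_sequence_len)
    return tuple(qkv.shape), seqlen_offset, max_sequence_len

for step in range(3):
    print(rotary_call_args(step, prompt_len=16, max_sequence_len=2048))
# ((2, 16, 3, 16, 64), 0, 2048), then ((2, 1, 3, 16, 64), 16, 2048), then ((2, 1, 3, 16, 64), 17, 2048)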
@@ -851,6 +858,9 @@ class ParallelMHA(nn.Module):
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
@@ -859,7 +869,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -889,7 +901,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
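As a rough, purely illustrative way to see what the plumbing buys (an assumption about intent; the commit message itself only states that max_seqlen is now passed through during inference), the toy counter below compares how often a length-based cache would have to grow during a generation with and without an up-front max_seqlen.

def count_cache_resizes(max_seqlen=None, prompt_len=16, new_tokens=32):
    # Toy model of a cos/sin cache that grows whenever the requested length exceeds it.
    cached, resizes = 0, 0
    for step in range(new_tokens + 1):          # step 0 = prompt, then one token per step
        seqlen = prompt_len if step == 0 else 1
        seqlen_offset = 0 if step == 0 else prompt_len + step - 1
        needed = max_seqlen if max_seqlen is not None else seqlen + seqlen_offset
        if needed > cached:
            cached, resizes = needed, resizes + 1
    return resizes

print(count_cache_resizes(max_seqlen=None))  # 33: the cache grows at the prompt step and at every decode step
print(count_cache_resizes(max_seqlen=2048))  # 1: a single up-front allocation covers the whole generation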