[Rotary] Pass max_seqlen from mha.py to rotary during inference

This commit is contained in:
parent 942fcbf046
commit de2949f37d
@@ -402,10 +402,10 @@ class RotaryEmbedding(torch.nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        elif max_seqlen is not None:
+        if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if kv is None:
             if self.scale is None:
                 return apply_rotary_emb_qkv_(
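The reordering makes an explicitly supplied max_seqlen take priority. With the old ordering, an integer seqlen_offset (the common case) always won, so a max_seqlen passed by the caller was ignored and the cos/sin cache grew from seqlen + seqlen_offset at each decoding step. Below is a minimal sketch of the new ordering; the class, the grow-on-demand policy, and all names in it are illustrative assumptions, not the library source.

# Illustrative sketch only (assumed names): shows why checking max_seqlen
# before seqlen + seqlen_offset avoids repeated cache rebuilds during decoding.
from typing import Optional, Union

import torch


class ToyRotaryCache:
    def __init__(self, dim: int, base: float = 10000.0):
        self.dim = dim
        self.base = base
        self._seq_len_cached = 0
        self.cos_cached: Optional[torch.Tensor] = None
        self.sin_cached: Optional[torch.Tensor] = None
        self.rebuilds = 0

    def _update_cos_sin_cache(self, seqlen: int) -> None:
        # Grow-on-demand policy assumed here: rebuild only when a longer table is needed.
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            self.rebuilds += 1
            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            t = torch.arange(seqlen, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            self.cos_cached, self.sin_cached = freqs.cos(), freqs.sin()

    def prepare(
        self,
        seqlen: int,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> None:
        # Same ordering as the patched forward: an explicit max_seqlen wins.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset)


cache = ToyRotaryCache(dim=64)
cache.prepare(seqlen=16, seqlen_offset=0, max_seqlen=512)  # sized once for the whole run
for step in range(8):
    cache.prepare(seqlen=1, seqlen_offset=16 + step, max_seqlen=512)
assert cache.rebuilds == 1  # no per-step regrowth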
@@ -606,6 +606,9 @@ class MHA(nn.Module):
             else {"key_padding_mask": key_padding_mask, **kwargs}
         )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
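Alongside seqlen_offset, MHA.forward now derives rotary_max_seqlen from the inference params and threads it to the rotary call sites below. The stand-in sketch that follows shows that derivation; the dataclass, its defaults, and the helper are assumptions for illustration, not the library's generation utilities.

# Stand-in for the inference-params object (assumed names and defaults), used
# only to show how the two values consumed by MHA.forward are derived.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToyInferenceParams:
    max_sequence_len: int          # total length the KV / rotary caches must cover
    sequence_len_offset: int = 0   # tokens already held in the KV cache
    fused_ft_kernel: bool = False  # whether a fused decoding kernel is in use


def derive_rotary_args(params: Optional[ToyInferenceParams]) -> Tuple[int, Optional[int]]:
    seqlen_offset = 0 if params is None else params.sequence_len_offset
    rotary_max_seqlen = params.max_sequence_len if params is not None else None
    return seqlen_offset, rotary_max_seqlen


assert derive_rotary_args(None) == (0, None)                          # training: no cache
assert derive_rotary_args(ToyInferenceParams(512, 16)) == (16, 512)   # decode step 16 of 512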
@@ -623,7 +626,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -669,7 +674,9 @@ class MHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
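The same keyword is added on both MHA call shapes: the packed-qkv branch taken when `not self.cross_attn and self.num_heads_kv == self.num_heads`, and the separate q / kv branch used for cross-attention or a reduced number of KV heads. A sketch of the two call forms follows; the signature and tensor shapes are assumptions inferred from the rearrange patterns visible in this diff, not the library's definition.

# Sketch of the two call shapes (assumed signature and shapes).
from typing import Optional, Union

import torch


def rotary_forward(
    qkv: torch.Tensor,                          # (b, s, 3, h, d) packed, or (b, s, h, d) query
    kv: Optional[torch.Tensor] = None,          # (b, s, 2, h_kv, d) when q and kv are separate
    seqlen_offset: Union[int, torch.Tensor] = 0,
    max_seqlen: Optional[int] = None,           # new keyword forwarded by mha.py
):
    # Placeholder: the real module rotates q / k in place using its cos/sin cache.
    return qkv if kv is None else (qkv, kv)


qkv = torch.zeros(2, 16, 3, 8, 64)
q, kv = torch.zeros(2, 16, 8, 64), torch.zeros(2, 16, 2, 2, 64)

qkv = rotary_forward(qkv, seqlen_offset=0, max_seqlen=512)      # self-attention branch
q, kv = rotary_forward(q, kv, seqlen_offset=0, max_seqlen=512)  # cross-attn / MQA-GQA branch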
@@ -851,6 +858,9 @@ class ParallelMHA(nn.Module):
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
@@ -859,7 +869,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -889,7 +901,9 @@ class ParallelMHA(nn.Module):
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
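Taken together, the call sites in MHA and ParallelMHA now pass both values on every forward during generation: the offset advances each step while the maximum stays fixed, so the rotary cache can be sized once up front. A usage-level sketch under assumed names:

# Usage sketch (assumed names): the per-step offset moves, the max length does not,
# so the rotary cos/sin cache can be built once instead of being regrown per token.
prompt_len, max_new_tokens = 16, 32
max_sequence_len = prompt_len + max_new_tokens   # what the inference params would carry
sequence_len_offset = 0                          # how much is already in the KV cache

# Prefill: the whole prompt in one forward pass.
seqlen_offset, rotary_max_seqlen = sequence_len_offset, max_sequence_len   # (0, 48)
sequence_len_offset += prompt_len

# Decode: one token per step; only the offset changes.
for _ in range(max_new_tokens):
    seqlen_offset, rotary_max_seqlen = sequence_len_offset, max_sequence_len
    # rotary_emb(qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen)
    sequence_len_offset += 1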