[Rotary] Pass max_seqlen from mha.py to rotary during inference

Tri Dao 2023-09-03 11:36:49 -07:00
parent 942fcbf046
commit de2949f37d
2 changed files with 21 additions and 7 deletions

flash_attn/layers/rotary.py

@@ -402,10 +402,10 @@ class RotaryEmbedding(torch.nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if isinstance(seqlen_offset, int):
-            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
-        elif max_seqlen is not None:
+        if max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
         if kv is None:
             if self.scale is None:
                 return apply_rotary_emb_qkv_(
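
The reordered branch in RotaryEmbedding.forward gives max_seqlen priority over seqlen + seqlen_offset when deciding how far to extend the cos/sin cache. Below is a minimal standalone sketch of just that selection logic; the helper name pick_cache_len and the tensor shapes are illustrative, not the library's actual cache code.

from typing import Optional, Union
import torch

def pick_cache_len(
    qkv: torch.Tensor,
    seqlen_offset: Union[int, torch.Tensor] = 0,
    max_seqlen: Optional[int] = None,
) -> Optional[int]:
    # Re-statement of the new branch order in RotaryEmbedding.forward:
    # prefer an explicit max_seqlen (known up front during inference), and only
    # fall back to seqlen + seqlen_offset when seqlen_offset is a plain int.
    seqlen = qkv.shape[1]  # qkv is (batch, seqlen, 3, nheads, headdim)
    if max_seqlen is not None:
        return max_seqlen
    elif isinstance(seqlen_offset, int):
        return seqlen + seqlen_offset
    return None  # per-sample tensor offsets without max_seqlen: cache left as-is

qkv = torch.randn(2, 1, 3, 16, 64)  # a single decoding step
print(pick_cache_len(qkv, seqlen_offset=10))                   # 11: fallback path
print(pick_cache_len(qkv, seqlen_offset=10, max_seqlen=2048))  # 2048: max_seqlen now wins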

flash_attn/modules/mha.py

@@ -606,6 +606,9 @@ class MHA(nn.Module):
                 else {"key_padding_mask": key_padding_mask, **kwargs}
             )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
             if not self.return_residual:
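
Both MHA.forward and ParallelMHA.forward now derive rotary_max_seqlen from the inference parameters before dispatching to rotary. A hedged sketch of that derivation, using a stand-in dataclass rather than flash-attn's actual InferenceParams (only the two attributes referenced in this diff are modeled):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeInferenceParams:
    # Stand-in for the inference-params object used in mha.py.
    max_sequence_len: int
    sequence_len_offset: int = 0

def rotary_args(inference_params: Optional[FakeInferenceParams]):
    # Mirrors the two values computed in this commit: during training
    # (no inference_params) they degrade to 0 / None.
    seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
    rotary_max_seqlen = (
        inference_params.max_sequence_len if inference_params is not None else None
    )
    return seqlen_offset, rotary_max_seqlen

print(rotary_args(None))                                               # (0, None)
print(rotary_args(FakeInferenceParams(2048, sequence_len_offset=17)))  # (17, 2048)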
@@ -623,7 +626,9 @@
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
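
Why thread max_seqlen through at all: during step-by-step decoding qkv has length 1 and seqlen_offset grows by one each call, so without an upper bound the cos/sin cache keeps getting re-extended; sizing it once to the maximum sequence length avoids that (and covers the case where seqlen_offset is a per-sample tensor). A toy sketch of the effect, with a stand-in cache class rather than the real RotaryEmbedding:

import torch

class ToyCosSinCache:
    # Toy stand-in for RotaryEmbedding's cos/sin cache; it only counts how often
    # the cache would be (re)allocated, no actual rotation is applied.
    def __init__(self):
        self._cached = 0
        self.allocations = 0

    def _update(self, seqlen: int):
        if seqlen > self._cached:
            self._cached = seqlen
            self.allocations += 1

    def __call__(self, qkv, *, seqlen_offset=0, max_seqlen=None):
        # Same precedence as the patched RotaryEmbedding.forward.
        if max_seqlen is not None:
            self._update(max_seqlen)
        elif isinstance(seqlen_offset, int):
            self._update(qkv.shape[1] + seqlen_offset)
        return qkv

def decode(max_new_tokens: int, pass_max_seqlen: bool) -> int:
    rotary = ToyCosSinCache()
    for step in range(max_new_tokens):
        qkv = torch.randn(1, 1, 3, 4, 8)  # one token per step during decoding
        rotary(qkv, seqlen_offset=step,
               max_seqlen=max_new_tokens if pass_max_seqlen else None)
    return rotary.allocations

print(decode(256, pass_max_seqlen=False))  # cache re-extended at every step (256 here)
print(decode(256, pass_max_seqlen=True))   # cache sized once up front (1)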
@@ -669,7 +674,9 @@
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
@@ -851,6 +858,9 @@ class ParallelMHA(nn.Module):
         if seqlen is not None:
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
+        rotary_max_seqlen = (
+            inference_params.max_sequence_len if inference_params is not None else None
+        )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
             if (
@@ -859,7 +869,9 @@
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
+                    qkv = self.rotary_emb(
+                        qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
@@ -889,7 +901,9 @@
                 or not inference_params.fused_ft_kernel
             ):
                 if self.rotary_emb_dim > 0:
-                    q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
+                    q, kv = self.rotary_emb(
+                        q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen
+                    )
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)