[Rotary] Customize base, support seqlen_offset
This commit is contained in:
parent
d6ef701aa9
commit
71f674ae23
@ -136,20 +136,20 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, dim_model: int, *_, **__):
|
||||
def __init__(self, dim: int, base=10000, *_, **__):
|
||||
super().__init__()
|
||||
# Generate and save the inverse frequency buffer (non trainable)
|
||||
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
|
||||
self._seq_len_cached = 0
|
||||
self._cos_cached = None
|
||||
self._sin_cached = None
|
||||
|
||||
def _update_cos_sin_cache(self, x):
|
||||
def _update_cos_sin_cache(self, x, seqlen_offset=0):
|
||||
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
|
||||
"""
|
||||
seqlen = x.shape[1]
|
||||
seqlen = x.shape[1] + seqlen_offset
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
|
||||
@ -162,6 +162,11 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
self._cos_cached = torch.cos(freqs).to(x.dtype)
|
||||
self._sin_cached = torch.sin(freqs).to(x.dtype)
|
||||
|
||||
def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
self._update_cos_sin_cache(qkv)
|
||||
return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
|
||||
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
seqlen_offset: can be used in generation where the qkv being passed in is only the last
|
||||
token in the batch.
|
||||
"""
|
||||
self._update_cos_sin_cache(qkv, seqlen_offset)
|
||||
return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
|
||||
self._sin_cached[seqlen_offset:])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user