[Rotary] Customize base, support seqlen_offset

This commit is contained in:
Tri Dao 2022-11-17 11:43:36 -08:00
parent d6ef701aa9
commit 71f674ae23

View File

@ -136,20 +136,20 @@ class RotaryEmbedding(torch.nn.Module):
"""
def __init__(self, dim_model: int, *_, **__):
def __init__(self, dim: int, base=10000, *_, **__):
super().__init__()
# Generate and save the inverse frequency buffer (non trainable)
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self._seq_len_cached = 0
self._cos_cached = None
self._sin_cached = None
def _update_cos_sin_cache(self, x):
def _update_cos_sin_cache(self, x, seqlen_offset=0):
"""x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
"""
seqlen = x.shape[1]
seqlen = x.shape[1] + seqlen_offset
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
@ -162,6 +162,11 @@ class RotaryEmbedding(torch.nn.Module):
self._cos_cached = torch.cos(freqs).to(x.dtype)
self._sin_cached = torch.sin(freqs).to(x.dtype)
def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
self._update_cos_sin_cache(qkv)
return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
"""
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
self._update_cos_sin_cache(qkv, seqlen_offset)
return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
self._sin_cached[seqlen_offset:])