diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index 059b72b..da1b049 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -136,20 +136,20 @@ class RotaryEmbedding(torch.nn.Module):
     """

-    def __init__(self, dim_model: int, *_, **__):
+    def __init__(self, dim: int, base=10000, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)

         self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None

-    def _update_cos_sin_cache(self, x):
+    def _update_cos_sin_cache(self, x, seqlen_offset=0):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
         """
-        seqlen = x.shape[1]
+        seqlen = x.shape[1] + seqlen_offset
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
@@ -162,6 +162,11 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_cached = torch.cos(freqs).to(x.dtype)
             self._sin_cached = torch.sin(freqs).to(x.dtype)

-    def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv)
-        return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
+    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        seqlen_offset: can be used in generation where the qkv being passed in is only the last
+        token in the batch.
+        """
+        self._update_cos_sin_cache(qkv, seqlen_offset)
+        return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
+                                     self._sin_cached[seqlen_offset:])
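
A minimal usage sketch of the new seqlen_offset parameter (not part of the diff). It assumes flash_attn is installed with its CUDA rotary kernel and that the updated class is importable as flash_attn.layers.rotary.RotaryEmbedding; shapes follow the docstring above, qkv being (batch, seqlen, 3, nheads, headdim).

    # Sketch only: assumes flash_attn with its CUDA rotary kernel is available.
    import torch
    from flash_attn.layers.rotary import RotaryEmbedding

    batch, nheads, headdim = 2, 8, 64
    # rotary dim matches headdim; base defaults to 10000 as in the diff
    rotary = RotaryEmbedding(dim=headdim).cuda()

    # Prefill: rotate the whole prompt, positions 0..prompt_len-1.
    prompt_len = 16
    qkv = torch.randn(batch, prompt_len, 3, nheads, headdim,
                      dtype=torch.float16, device="cuda")
    rotary(qkv, seqlen_offset=0)  # apply_rotary_emb_qkv_'s trailing underscore suggests qkv is rotated in place

    # Incremental decoding: qkv holds only the newly generated token, so pass the
    # number of tokens already processed as seqlen_offset to get the right position.
    new_qkv = torch.randn(batch, 1, 3, nheads, headdim,
                          dtype=torch.float16, device="cuda")
    rotary(new_qkv, seqlen_offset=prompt_len)

Because _update_cos_sin_cache adds seqlen_offset to x.shape[1], the cos/sin tables cover the absolute positions, and slicing them with [seqlen_offset:] in forward lines the cached values up with the single decoded token.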