Merge pull request #299 from proger/rotary-inference-mode

rotary: update cos/sin cache when switching from inference mode
This commit is contained in:
Tri Dao 2023-07-08 12:16:51 -04:00 committed by GitHub
commit 72ad03eaa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -211,9 +211,11 @@ class RotaryEmbedding(torch.nn.Module):
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
# Reset the tables if the sequence length has changed,
# or if we're on a new device (possibly due to tracing for instance)
# if we're on a new device (possibly due to tracing for instance),
# or if we're switching from inference mode to training
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
or self._cos_cached.dtype != dtype):
or self._cos_cached.dtype != dtype
or (self.training and self._cos_cached.is_inference())):
self._seq_len_cached = seqlen
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
# And the output of arange can be quite large, so bf16 would lose a lot of precision.