Merge pull request #299 from proger/rotary-inference-mode
rotary: update cos/sin cache when switching from inference mode
This commit is contained in:
commit
72ad03eaa6
@ -211,9 +211,11 @@ class RotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
|
||||
# Reset the tables if the sequence length has changed,
|
||||
# or if we're on a new device (possibly due to tracing for instance)
|
||||
# if we're on a new device (possibly due to tracing for instance),
|
||||
# or if we're switching from inference mode to training
|
||||
if (seqlen > self._seq_len_cached or self._cos_cached.device != device
|
||||
or self._cos_cached.dtype != dtype):
|
||||
or self._cos_cached.dtype != dtype
|
||||
or (self.training and self._cos_cached.is_inference())):
|
||||
self._seq_len_cached = seqlen
|
||||
# We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
|
||||
# And the output of arange can be quite large, so bf16 would lose a lot of precision.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user