diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index 0cb5632..bfda386 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -211,9 +211,11 @@ class RotaryEmbedding(torch.nn.Module):
 
     def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
         # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
         if (seqlen > self._seq_len_cached or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype):
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())):
             self._seq_len_cached = seqlen
             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
             # And the output of arange can be quite large, so bf16 would lose a lot of precision.
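
Note (not part of the patch): a minimal, assumed sketch of why the extra condition is needed. Tensors created under torch.inference_mode() are "inference tensors" and cannot be saved for backward, so a cos/sin cache built during inference must be rebuilt once the module is training again. The names `cos_cached` and `x` below are illustrative only.

import torch

# A tensor created under torch.inference_mode() is flagged as an inference tensor.
with torch.inference_mode():
    cos_cached = torch.cos(torch.arange(8, dtype=torch.float32))

print(cos_cached.is_inference())  # True

# Using an inference tensor in an autograd-recording forward pass raises a
# RuntimeError, which is why the patched condition
# `self.training and self._cos_cached.is_inference()` forces the cache to be
# regenerated before training continues.
x = torch.randn(8, requires_grad=True)
try:
    (x * cos_cached).sum().backward()
except RuntimeError as err:
    print(err)  # "Inference tensors cannot be saved for backward. ..."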