rotary: update cos/sin cache when switching from inference mode
This resolves RuntimeErrors raised on the first training step after running evaluation in inference mode:
```
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
qkv = self.rotary_emb(qkv)
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
return apply_rotary_emb_qkv_(
File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
commit 70ab266a56
parent 2800efc71f
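For context, here is a minimal sketch (illustrative code, not from flash_attn) of the underlying PyTorch behavior: a tensor created under `torch.inference_mode()` is flagged as an inference tensor, and any later autograd-tracked op that needs to save it for backward raises the error above. `Tensor.is_inference()` detects this state, which is what the patch below checks.

```python
import torch

# Sketch of the failure mode (illustrative, not flash_attn code):
# a tensor built under inference_mode cannot be saved for backward later.
with torch.inference_mode():
    cos_cached = torch.randn(16, 8)  # stands in for the rotary cos/sin cache

assert cos_cached.is_inference()  # flagged as an inference tensor

x = torch.randn(16, 8, requires_grad=True)
# mul must save cos_cached for x's gradient, so this forward call raises
# "RuntimeError: Inference tensors cannot be saved for backward."
y = x * cos_cached
```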
```diff
@@ -211,9 +211,11 @@ class RotaryEmbedding(torch.nn.Module):
     def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
         # Reset the tables if the sequence length has changed,
-        # or if we're on a new device (possibly due to tracing for instance)
+        # if we're on a new device (possibly due to tracing for instance),
+        # or if we're switching from inference mode to training
         if (seqlen > self._seq_len_cached or self._cos_cached.device != device
-            or self._cos_cached.dtype != dtype):
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())):
             self._seq_len_cached = seqlen
             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
             # And the output of arange can be quite large, so bf16 would lose a lot of precision.
             t = torch.arange(seqlen, device=device, dtype=torch.float32)
```
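A rough sketch of the cache-refresh pattern this enables, using a simplified stand-in module (names and shapes are illustrative, not the real `RotaryEmbedding` API):

```python
import torch

class TinyRotaryCache(torch.nn.Module):
    """Simplified stand-in for RotaryEmbedding's caching logic (illustrative only)."""
    def __init__(self):
        super().__init__()
        self._seq_len_cached = 0
        self._cos_cached = torch.empty(0)

    def _update(self, seqlen):
        # Rebuild if the sequence grew, or if the cache was built under
        # inference mode and we are now training (mirrors the patch above).
        if (seqlen > self._seq_len_cached
                or (self.training and self._cos_cached.is_inference())):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, dtype=torch.float32)
            self._cos_cached = torch.cos(t)[:, None]

module = TinyRotaryCache().eval()
with torch.inference_mode():
    module._update(128)  # cache is created as an inference tensor

module.train()
module._update(128)      # same seqlen, but the stale inference cache is rebuilt
assert not module._cos_cached.is_inference()
```

Without the extra check, the second call would reuse the inference-mode cache and crash inside `apply_rotary_emb_qkv_` as in the traceback above.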