Commit Graph

12 Commits

Author SHA1 Message Date
Volodymyr Kyrylov
70ab266a56 rotary: update cos/sin cache when switching from inference mode
This resolves RuntimeErrors after running evaluation in inference mode:

```
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
    qkv = self.rotary_emb(qkv)
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
    return apply_rotary_emb_qkv_(
  File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00
Tri Dao
e45a46a5b7 [Rotary] Implement GPT-J style (interleaved) rotary 2023-03-14 14:35:53 -07:00
Tri Dao
85b8e3d334 [Docs] Mention that XPos's scale_base is recommended to be 512 2022-12-29 20:25:02 -08:00
Tri Dao
1e712ea8b0 Implement TensorParallel for MHA 2022-12-25 11:39:55 -08:00
Tri Dao
496e4f528c Implement XPos (Sun et al.) 2022-12-21 14:17:58 -08:00
Alexander Ploshkin
ee8984d2be add asserts for sin shape 2022-12-17 13:34:57 +04:00
Alexander Ploshkin
c7c66976cc fix slicing dimensions 2022-12-16 15:39:06 +04:00
Alexander Ploshkin
96656b9323 Remove redundant shape asserts in rotary embeddings 2022-12-15 18:13:21 +04:00
Tri Dao
71f674ae23 [Rotary] Customize base, support seqlen_offset 2022-11-17 11:43:36 -08:00
Tri Dao
d4b320b31f Add MLP, MHA, Block, Embedding modules 2022-11-13 22:06:44 -08:00