Tri Dao
31ae2488e6
Merge pull request #343 from danthe3rd/if_constexpr
...
Fix compile error with `BOOL_SWITCH`
2023-07-19 04:27:07 -04:00
danthe3rd
538d570c96
Fix compile error on MSVC
...
See also: https://stackoverflow.com/questions/55136414/constexpr-variable-captured-inside-lambda-loses-its-constexpr-ness
2023-07-19 08:04:57 +00:00
Tri Dao
d1a3b52f17
Add instruction about limiting number of ninja jobs
2023-07-17 23:17:47 -07:00
Tri Dao
b4cc152e97
Make sure dout is contiguous
2023-07-17 21:54:44 -07:00
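A minimal sketch of the kind of guard this commit implies, in PyTorch; the helper name is hypothetical:
```
import torch

def ensure_contiguous(dout: torch.Tensor) -> torch.Tensor:
    # Gradients coming out of autograd can be non-contiguous
    # (e.g. produced by a transpose or a view); the CUDA backward
    # kernel assumes a contiguous layout, so normalize first.
    return dout if dout.is_contiguous() else dout.contiguous()
```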
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c
Bump to v1.0.9
2023-07-17 03:16:40 -07:00
Tri Dao
01c40dacc4
Merge pull request #313 from philipturner/patch-1
...
Metal FlashAttention
2023-07-15 20:36:48 -04:00
Philip Turner
4dbcaa1443
Update usage.md
2023-07-15 08:40:46 -04:00
Philip Turner
905c13a2d9
Update usage.md
2023-07-15 01:55:43 -04:00
Philip Turner
6ababeb7db
Update usage.md
2023-07-15 01:34:24 -04:00
Tri Dao
72ad03eaa6
Merge pull request #299 from proger/rotary-inference-mode
...
rotary: update cos/sin cache when switching from inference mode
2023-07-08 12:16:51 -04:00
Volodymyr Kyrylov
70ab266a56
rotary: update cos/sin cache when switching from inference mode
...
This resolves RuntimeErrors after running evaluation in inference mode:
```
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
qkv = self.rotary_emb(qkv)
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
return apply_rotary_emb_qkv_(
File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
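A hedged sketch of the fix's logic, with hypothetical names: a cos/sin cache built under `torch.inference_mode()` holds inference tensors, which cannot be saved for backward, so the cache must be detected and rebuilt once gradients are needed again.
```
import torch

def needs_cache_rebuild(cos_cached, seqlen_cached, seqlen):
    # Rebuild if the cache is missing, too short, or was created
    # under inference_mode (inference tensors cannot be saved for
    # backward, hence the RuntimeError above).
    return (
        cos_cached is None
        or seqlen > seqlen_cached
        or cos_cached.is_inference()
    )

with torch.inference_mode():
    cache = torch.randn(128, 64)             # cache built during evaluation
print(needs_cache_rebuild(cache, 128, 128))  # True: rebuild before training
```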
Tri Dao
2800efc71f
[FT] rotary_cos/sin should have batch_size dimension
2023-07-06 15:33:33 -07:00
Tri Dao
d2f4324f4c
[LayerNorm] Make sure memory addresses are aligned to 16 bytes
2023-07-04 14:53:12 -07:00
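A small PyTorch sketch of the alignment requirement this commit enforces; the checker function is illustrative, not the library's API:
```
import torch

def is_16_byte_aligned(t: torch.Tensor) -> bool:
    # Vectorized 128-bit loads/stores require the base address to be
    # a multiple of 16 bytes; data_ptr() exposes the raw address.
    return t.data_ptr() % 16 == 0

x = torch.randn(8, 1024)
print(is_16_byte_aligned(x))         # fresh allocations are aligned
print(is_16_byte_aligned(x[:, 1:]))  # a 4-byte-offset slice is not
```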
Tri Dao
3a9bfd076f
[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)
2023-07-03 09:41:04 -07:00
Tri Dao
e8a0b4acdd
[Doc] Change total -> total_q
2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8
Bump to v1.0.8
2023-07-02 17:04:54 -07:00
Tri Dao
a5d8714c26
[Build] Remove pyproject.toml
...
I haven't found an easy way to add torch as a build dependency in
pyproject.toml.
If we add torch in pyproject.toml, for some setups it would download a
different version of PyTorch before building.
If we don't add torch, lots of users report errors when installing.
2023-07-02 17:02:49 -07:00
Tri Dao
62e9814466
[Rotary] Make sure frequency calculation is in fp32
2023-07-02 16:39:39 -07:00
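The standard rotary frequency computation, sketched in PyTorch to show why fp32 matters: exponentiating in fp16/bf16 loses precision for large `dim`. The function name is illustrative.
```
import torch

def rotary_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Doing the arange, division, and power in float32 keeps the
    # frequencies accurate even when the model runs in fp16/bf16.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
```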
Tri Dao
9818f85fee
Merge pull request #255 from beginlner/main
...
Fix a bug
2023-06-02 02:23:25 -04:00
ljss
8e44c0eefb
Fix a bug
2023-06-02 13:46:19 +08:00
Tri Dao
85b51d61ee
Bump version to 1.0.7
2023-05-30 14:18:44 -07:00
Tri Dao
27f8f890df
[FusedDense] Allocate lt_workspace on input device
2023-05-30 14:17:26 -07:00
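A sketch of what "allocate on the input device" means in PyTorch; the workspace size here is an arbitrary placeholder, not the value the library uses:
```
import torch

def make_lt_workspace(x: torch.Tensor, size_bytes: int = 1 << 22):
    # Allocating the cuBLASLt workspace on the input's device (rather
    # than the current default device) keeps multi-GPU runs correct.
    return torch.empty(size_bytes, dtype=torch.uint8, device=x.device)
```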
Tri Dao
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
2023-05-30 13:38:34 -07:00
Tri Dao
7c766b1bbc
Merge pull request #243 from ksivaman/bump_version_to_v1_0_6
...
bump to v1.0.6
2023-05-26 22:48:08 -04:00
Kirthi Shankar Sivamani
dd9c3a1fc2
bump to v1.0.6
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-26 17:44:10 -07:00
Tri Dao
ce68305c84
Update installation instruction
2023-05-25 16:52:52 -07:00
Tri Dao
cf4f0a39f3
Merge pull request #241 from ksivaman/fix_compilation_time
...
Fix compilation time
2023-05-25 18:34:41 -04:00
Kirthi Shankar Sivamani
6d45d0bd6c
Re-add ninja
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-25 21:22:50 +00:00
Kirthi Shankar Sivamani
852bc40b8c
Remove torch from pyproject.toml
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-25 19:12:22 +00:00
Kirthi Shankar Sivamani
c1d117c2d0
Remove ninja from pyproject.toml
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-25 19:12:00 +00:00
Tri Dao
f0c40b7ddb
Recommend Nvidia's Pytorch container
2023-05-19 09:41:14 -07:00
Tri Dao
3cad2ab35d
Merge pull request #229 from maxhgerlach/local-version
...
Allow adding an optional local version to the package version
2023-05-19 11:43:24 -04:00
Max H. Gerlach
31f78a9814
Allow adding an optional local version to the package version
2023-05-19 17:27:41 +02:00
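Presumably this works along the following lines; the environment-variable name below is hypothetical, and PEP 440 puts the local version after a `+`:
```
import os

local = os.environ.get("FLASH_ATTN_LOCAL_VERSION")  # hypothetical name
version = "1.0.5" + (f"+{local}" if local else "")
print(version)  # e.g. "1.0.5+cu118" when the variable is set
```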
Tri Dao
40a25c8ee7
Update roadmap
2023-05-17 08:32:26 -07:00
Tri Dao
eff9fe6b80
Add ninja to pyproject.toml build-system, bump to v1.0.5
2023-05-12 14:20:31 -07:00
Tri Dao
36d0a19f1e
Merge pull request #193 from anthonyhu/pyproject-build
...
Use pyproject.toml to specify build dependencies
2023-05-11 21:26:28 -04:00
Tri Dao
5bf7f57d47
Merge pull request #202 from fedebotu/main
...
[BugFix] avoid bug on ImportError
2023-05-06 14:15:02 -04:00
Federico Berto
69f5f7d0a2
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:44 +09:00
Federico Berto
3889ba168b
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:30 +09:00
Tri Dao
a9a4b4e4f2
[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
2023-05-04 23:39:43 -07:00
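For reference, RMSNorm differs from LayerNorm in that it neither centers the input nor adds a bias; a minimal PyTorch sketch (not the repo's fused kernel):
```
import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Scale by the reciprocal root-mean-square of the last dim;
    # no mean subtraction and no bias, matching LLaMa's norm.
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(eps).rsqrt()
    return x * rms * weight
```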
Anthony Hu
d63cfc3551
Use pyproject.toml to specify build dependencies
2023-04-27 11:51:52 +01:00
Tri Dao
ad113948a6
[Docs] Clearer error message for bwd d > 64, bump to v1.0.4
2023-04-26 09:19:48 -07:00
Tri Dao
fbbb107848
Bump version to v1.0.3.post0
2023-04-21 13:37:23 -07:00
Tri Dao
67ef5d28df
Bump version to 1.0.3
2023-04-21 12:04:53 -07:00
Tri Dao
fcab93b43a
[Gen] Minor tweak to allocate_inference_cache
2023-04-21 11:56:47 -07:00
Tri Dao
ba2fe7f378
[Gen] Move allocate_inference_cache to within the model
2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1
[GPT] Add option to only return the logit for the last token
2023-04-20 17:21:08 -07:00
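During autoregressive generation only the final position's logits feed the sampler, so skipping the rest saves memory and compute; a sketch of the idea:
```
import torch

logits = torch.randn(4, 512, 50257)  # (batch, seqlen, vocab)
last = logits[:, -1]                 # (batch, vocab): next-token logits only
```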
Tri Dao
311d6606bf
[Gen] Fix FT kernel smem size, CG when batch size changed
2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545
Implement LLaMa
2023-04-18 21:51:35 -07:00