Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp ( #251 )
2023-07-26 12:14:15 -07:00
Tri Dao
56ccaff126
[GPT] Add LLaMa-13B to test
2023-07-26 07:22:22 -10:00
Tri Dao
8e9820a55b
[Rotary] Fix tests when loading state dict with rotary inv_freqs
2023-07-26 07:16:33 -10:00
Tri Dao
b252072409
Bump to v2.0.1
2023-07-23 12:33:42 -10:00
Tri Dao
2a2a3c4bfd
[LayerNorm] Add test for randomness
2023-07-23 12:31:55 -10:00
Joel Lamy-Poirier
767b71ccf0
Fix random state for dropout_layer_norm ( #315 )
2023-07-23 15:05:13 -07:00
Tri Dao
d38357dd2f
[GPT] Implement Falcon
2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert ( #363 )
2023-07-23 00:21:45 -07:00
Ian Timmis
cbf982afa5
README syntax highlighting ( #365 )
...
* README syntax highlighting
Adds syntax highlighting to README
* Update README.md
2023-07-23 00:21:30 -07:00
Tri Dao
425dbcb6c6
[MHA] Implement MQA/GQA
2023-07-23 00:06:58 -07:00
Tri Dao
ec9f74ab9a
[Rotary] Don't store inv_freq in state_dict
2023-07-22 23:52:42 -07:00
Tri Dao
a157cc8c9b
[FT] Implement MQA/GQA
2023-07-22 23:47:01 -07:00
Tri Dao
75e334d407
[MLP] Add ParallelMLP
2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6
[GPT] Enable FlashAttention for GPT-J
2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2
[Block] Re-enable DropPath
2023-07-21 16:39:23 -07:00
Tri Dao
9ee0ff1d9b
Fix using dO stride for O, which can cause memory error in bwd
2023-07-20 17:39:57 -07:00
Tri Dao
2dd87d0609
Merge pull request #360 from chuanli11/fix/dockerfile
...
remove checkout v2.0.0.post1 from dockerfile
2023-07-20 19:41:24 -04:00
chuanli11
30fd8c17d8
remove checkout v2.0.0.post1 from dockerfile
2023-07-20 16:40:15 +00:00
Tri Dao
b8020d73c9
Merge pull request #348 from eltociear/patch-2
...
[LayerNorm] Fix typo in ln_api.cpp
2023-07-19 17:25:37 -04:00
Ikko Eltociear Ashimine
dfc60f6b7d
[LayerNorm] Fix typo in ln_api.cpp
...
unintialized -> uninitialized
2023-07-20 01:16:16 +09:00
Tri Dao
31ae2488e6
Merge pull request #343 from danthe3rd/if_constexpr
...
Fix compile error with `BOOL_SWITCH`
2023-07-19 04:27:07 -04:00
danthe3rd
538d570c96
Fix compile error on MSVC
...
See also: https://stackoverflow.com/questions/55136414/constexpr-variable-captured-inside-lambda-loses-its-constexpr-ness
2023-07-19 08:04:57 +00:00
Tri Dao
d1a3b52f17
Add instruction about limiting number of ninja jobs
2023-07-17 23:17:47 -07:00
Tri Dao
b4cc152e97
Make sure dout is contiguous
2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c
Bump to v1.0.9
2023-07-17 03:16:40 -07:00
Tri Dao
01c40dacc4
Merge pull request #313 from philipturner/patch-1
...
Metal FlashAttention
2023-07-15 20:36:48 -04:00
Philip Turner
4dbcaa1443
Update usage.md
2023-07-15 08:40:46 -04:00
Philip Turner
905c13a2d9
Update usage.md
2023-07-15 01:55:43 -04:00
Philip Turner
6ababeb7db
Update usage.md
2023-07-15 01:34:24 -04:00
Tri Dao
72ad03eaa6
Merge pull request #299 from proger/rotary-inference-mode
...
rotary: update cos/sin cache when switching from inference mode
2023-07-08 12:16:51 -04:00
Volodymyr Kyrylov
70ab266a56
rotary: update cos/sin cache when switching from inference mode
...
This resolves RuntimeErrors after running evaluation in inference mode:
```
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
qkv = self.rotary_emb(qkv)
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
return apply_rotary_emb_qkv_(
File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
Tri Dao
2800efc71f
[FT] rotary_cos/sin should have batch_size dimension
2023-07-06 15:33:33 -07:00
Tri Dao
d2f4324f4c
[LayerNorm] Make sure memory addresses are aligned to 16 bytes
2023-07-04 14:53:12 -07:00
Tri Dao
3a9bfd076f
[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)
2023-07-03 09:41:04 -07:00
Tri Dao
e8a0b4acdd
[Doc] Change total -> total_q
2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8
Bump to v1.0.8
2023-07-02 17:04:54 -07:00
Tri Dao
a5d8714c26
[Build] Remove pyproject.toml
...
I haven't found an easy way to add torch as a build dependency in
pyproject.toml.
If we add torch in pyproject.toml, for some setup it would download a
different version of Pytorch before building.
If we don't add torch, lots of users report they get error when installing.
2023-07-02 17:02:49 -07:00
Tri Dao
62e9814466
[Rotary] Make sure frequency calculation is in fp32
2023-07-02 16:39:39 -07:00
Pierce Freeman
9af165c389
Clean setup.py imports
2023-06-07 17:27:36 -07:00
Pierce Freeman
eb812c205b
Remove builder project
2023-06-07 17:20:13 -07:00
Pierce Freeman
6c730dc8c6
Bump version
2023-06-07 17:07:14 -07:00
Pierce Freeman
494b2aa486
Add notes to github action workflow
2023-06-07 17:06:12 -07:00
Pierce Freeman
8d60c373e4
Add torch dependency to final build
2023-06-04 06:14:42 -07:00
Pierce Freeman
1848d0004f
Exclude cuda erroring builds
2023-06-04 06:14:42 -07:00
Pierce Freeman
84009fcc66
Exclude additional disallowed matrix params
2023-06-04 06:14:42 -07:00
Pierce Freeman
ac543b0e8d
Full version matrix
2023-06-04 06:14:42 -07:00
Pierce Freeman
a372e2be1b
Add CUDA 11.7
2023-06-04 06:14:42 -07:00
Pierce Freeman
18e100d312
Release is actually unsupported
2023-06-02 19:01:44 -07:00
Pierce Freeman
061470ae58
echo OS version
2023-06-02 18:59:09 -07:00