Commit Graph

553 Commits

Author SHA1 Message Date
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp (#251) 2023-07-26 12:14:15 -07:00
Tri Dao
56ccaff126 [GPT] Add LLaMa-13B to test 2023-07-26 07:22:22 -10:00
Tri Dao
8e9820a55b [Rotary] Fix tests when loading state dict with rotary inv_freqs 2023-07-26 07:16:33 -10:00
Tri Dao
b252072409 Bump to v2.0.1 2023-07-23 12:33:42 -10:00
Tri Dao
2a2a3c4bfd [LayerNorm] Add test for randomness 2023-07-23 12:31:55 -10:00
Joel Lamy-Poirier
767b71ccf0
Fix random state for dropout_layer_norm (#315) 2023-07-23 15:05:13 -07:00
Tri Dao
d38357dd2f [GPT] Implement Falcon 2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert (#363) 2023-07-23 00:21:45 -07:00
Ian Timmis
cbf982afa5
README syntax highlighting (#365)
* README syntax highlighting

Adds syntax highlighting to README

* Update README.md
2023-07-23 00:21:30 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
Tri Dao
ec9f74ab9a [Rotary] Don't store inv_freq in state_dict 2023-07-22 23:52:42 -07:00
Tri Dao
a157cc8c9b [FT] Implement MQA/GQA 2023-07-22 23:47:01 -07:00
Tri Dao
75e334d407 [MLP] Add ParallelMLP 2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6 [GPT] Enable FlashAttention for GPT-J 2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2 [Block] Re-enable DropPath 2023-07-21 16:39:23 -07:00
Tri Dao
9ee0ff1d9b Fix using dO stride for O, which can cause memory error in bwd 2023-07-20 17:39:57 -07:00
Tri Dao
2dd87d0609
Merge pull request #360 from chuanli11/fix/dockerfile
remove checkout v2.0.0.post1 from dockerfile
2023-07-20 19:41:24 -04:00
chuanli11
30fd8c17d8 remove checkout v2.0.0.post1 from dockerfile 2023-07-20 16:40:15 +00:00
Tri Dao
b8020d73c9
Merge pull request #348 from eltociear/patch-2
[LayerNorm] Fix typo in ln_api.cpp
2023-07-19 17:25:37 -04:00
Ikko Eltociear Ashimine
dfc60f6b7d
[LayerNorm] Fix typo in ln_api.cpp
unintialized -> uninitialized
2023-07-20 01:16:16 +09:00
Tri Dao
31ae2488e6
Merge pull request #343 from danthe3rd/if_constexpr
Fix compile error with `BOOL_SWITCH`
2023-07-19 04:27:07 -04:00
danthe3rd
538d570c96 Fix compile error on MSVC
See also: https://stackoverflow.com/questions/55136414/constexpr-variable-captured-inside-lambda-loses-its-constexpr-ness
2023-07-19 08:04:57 +00:00
Tri Dao
d1a3b52f17 Add instruction about limiting number of ninja jobs 2023-07-17 23:17:47 -07:00
Tri Dao
b4cc152e97 Make sure dout is contiguous 2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c Bump to v1.0.9 2023-07-17 03:16:40 -07:00
Tri Dao
01c40dacc4
Merge pull request #313 from philipturner/patch-1
Metal FlashAttention
2023-07-15 20:36:48 -04:00
Philip Turner
4dbcaa1443
Update usage.md 2023-07-15 08:40:46 -04:00
Philip Turner
905c13a2d9
Update usage.md 2023-07-15 01:55:43 -04:00
Philip Turner
6ababeb7db
Update usage.md 2023-07-15 01:34:24 -04:00
Tri Dao
72ad03eaa6
Merge pull request #299 from proger/rotary-inference-mode
rotary: update cos/sin cache when switching from inference mode
2023-07-08 12:16:51 -04:00
Volodymyr Kyrylov
70ab266a56 rotary: update cos/sin cache when switching from inference mode
This resolves RuntimeErrors after running evaluation in inference mode:

```
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
    qkv = self.rotary_emb(qkv)
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
    return apply_rotary_emb_qkv_(
  File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
Tri Dao
2800efc71f [FT] rotary_cos/sin should have batch_size dimension 2023-07-06 15:33:33 -07:00
Tri Dao
d2f4324f4c [LayerNorm] Make sure memory addresses are aligned to 16 bytes 2023-07-04 14:53:12 -07:00
Tri Dao
3a9bfd076f [FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim) 2023-07-03 09:41:04 -07:00
Tri Dao
e8a0b4acdd [Doc] Change total -> total_q 2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8 Bump to v1.0.8 2023-07-02 17:04:54 -07:00
Tri Dao
a5d8714c26 [Build] Remove pyproject.toml
I haven't found an easy way to add torch as a build dependency in
pyproject.toml.
If we add torch in pyproject.toml, for some setup it would download a
different version of Pytorch before building.
If we don't add torch, lots of users report they get error when installing.
2023-07-02 17:02:49 -07:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
Pierce Freeman
9af165c389 Clean setup.py imports 2023-06-07 17:27:36 -07:00
Pierce Freeman
eb812c205b Remove builder project 2023-06-07 17:20:13 -07:00
Pierce Freeman
6c730dc8c6 Bump version 2023-06-07 17:07:14 -07:00
Pierce Freeman
494b2aa486 Add notes to github action workflow 2023-06-07 17:06:12 -07:00
Pierce Freeman
8d60c373e4 Add torch dependency to final build 2023-06-04 06:14:42 -07:00
Pierce Freeman
1848d0004f Exclude cuda erroring builds 2023-06-04 06:14:42 -07:00
Pierce Freeman
84009fcc66 Exclude additional disallowed matrix params 2023-06-04 06:14:42 -07:00
Pierce Freeman
ac543b0e8d Full version matrix 2023-06-04 06:14:42 -07:00
Pierce Freeman
a372e2be1b Add CUDA 11.7 2023-06-04 06:14:42 -07:00
Pierce Freeman
18e100d312 Release is actually unsupported 2023-06-02 19:01:44 -07:00
Pierce Freeman
061470ae58 echo OS version 2023-06-02 18:59:09 -07:00