Commit Graph

157 Commits

Author SHA1 Message Date
Tri Dao
184b992dcb [GPT] Implement parallel LLaMa 2023-07-28 15:52:48 -10:00
Tri Dao
840f7925a0 [Docs] Fix mention of MQA/GQA in qkvpacked functions 2023-07-28 12:26:29 -10:00
Tri Dao
60499abcfd [Benchmark] Add script to benchmark FlashAttention 2023-07-28 00:26:52 -10:00
Kirthi Shankar Sivamani
32a953f486
Request for v2.0.2 (#388)
* Bump version to 2.0.2

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update version in Dockerfile

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-28 02:46:03 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save seed and offset for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Single thread write to global mem

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1colblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change forward c++ APIs to save RNG state for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change backward c++ APIs to set RNG state for bprop launcher

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python side API changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix; only save seeds instead of full offset

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Account for 3D grid size

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
4c98d0b41f [MLP] Edit ParallelGatedMlp 2023-07-26 09:39:37 -10:00
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp (#251) 2023-07-26 12:14:15 -07:00
Tri Dao
b252072409 Bump to v2.0.1 2023-07-23 12:33:42 -10:00
Tri Dao
d38357dd2f [GPT] Implement Falcon 2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert (#363) 2023-07-23 00:21:45 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
Tri Dao
ec9f74ab9a [Rotary] Don't store inv_freq in state_dict 2023-07-22 23:52:42 -07:00
Tri Dao
75e334d407 [MLP] Add ParallelMLP 2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6 [GPT] Enable FlashAttention for GPT-J 2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2 [Block] Re-enable DropPath 2023-07-21 16:39:23 -07:00
Tri Dao
b4cc152e97 Make sure dout is contiguous 2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c Bump to v1.0.9 2023-07-17 03:16:40 -07:00
Volodymyr Kyrylov
70ab266a56 rotary: update cos/sin cache when switching from inference mode
This resolves RuntimeErrors after running evaluation in inference mode:

```
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
    qkv = self.rotary_emb(qkv)
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
    return apply_rotary_emb_qkv_(
  File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
Tri Dao
d2f4324f4c [LayerNorm] Make sure memory addresses are aligned to 16 bytes 2023-07-04 14:53:12 -07:00
Tri Dao
e8a0b4acdd [Doc] Change total -> total_q 2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8 Bump to v1.0.8 2023-07-02 17:04:54 -07:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
ljss
8e44c0eefb
Fix a bug 2023-06-02 13:46:19 +08:00
Tri Dao
85b51d61ee Bump version to 1.0.7 2023-05-30 14:18:44 -07:00
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00
Kirthi Shankar Sivamani
dd9c3a1fc2 bump to v1.0.6
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-26 17:44:10 -07:00
Max H. Gerlach
31f78a9814 Allow adding an optional local version to the package version 2023-05-19 17:27:41 +02:00
Federico Berto
69f5f7d0a2 [BugFix] cannot unpack non-iterable NoneType object 2023-05-07 03:07:44 +09:00
Federico Berto
3889ba168b [BugFix] cannot unpack non-iterable NoneType object 2023-05-07 03:07:30 +09:00
Tri Dao
a9a4b4e4f2 [LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm 2023-05-04 23:39:43 -07:00
Tri Dao
fcab93b43a [Gen] Minor tweak to allocate_inference_cache 2023-04-21 11:56:47 -07:00
Tri Dao
ba2fe7f378 [Gen] Move allocate_inference_cache to within the model 2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1 [GPT] Add option to only return the logit for the last token 2023-04-20 17:21:08 -07:00
Tri Dao
311d6606bf [Gen] Fix FT kernel smem size, CG when batch size changed 2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f Implement GatedMlp 2023-04-18 03:37:14 -07:00
Tri Dao
ac3b684cdb Have a separate nn.Dropout module in SelfAttention module 2023-04-17 22:34:05 -07:00
Kirthi Shankar Sivamani
a0997bc77c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-14 21:45:37 -07:00
Tri Dao
605655bc66 [Gen] Fix FT kernel when using CG 2023-04-14 16:50:01 -07:00
Kirthi Shankar Sivamani
081c2b012a
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-13 19:36:45 -07:00
Tri Dao
1c9ef9b399 [Gen] Measure prompt processing + decoding time, not just decoding 2023-04-13 15:39:56 -07:00
Tri Dao
6f6e9a9aaf [FusedDense] Enable sqrelu activation in FusedMLP 2023-04-13 15:29:32 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
315fd31f0c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-12 22:42:24 -07:00
Zhiyuan Chen
8c42415664
make mlp hidden_features defaults to 4*in_features 2023-04-13 11:08:21 +08:00
Kirthi Shankar Sivamani
31018c5fa0 Support CUDA graph capture
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
d6fc860573
Merge pull request #147 from ksivaman/add_deterministic_execution_option
Add option for deterministic execution
2023-03-31 17:32:50 -04:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00
Kirthi Shankar Sivamani
b6aa059bbf Add option for deterministic execution 2023-03-30 18:23:35 -07:00