Commit Graph

181 Commits

Author SHA1 Message Date
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)
* q

* add comment.
2023-08-20 14:57:34 -07:00
Tri Dao
d431f16751 Import torch before flash_attn_2_cuda 2023-08-19 21:07:33 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format (#465)
* fix name.

* set inv function.

* add map back function.

* handle gqa.

* add type annotation to avoid confusion.

* fix docstr.

* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reuseable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
ada4710d70 [ViT] Run black on vit.py 2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1 [ViT] Minor fix so it runs 2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d [GPT] Run black on gpt.py 2023-08-16 23:47:50 -07:00
Tri Dao
bec5b3d374 [MHA] Run black on mha.py 2023-08-16 23:47:13 -07:00
Tri Dao
cb0daccc41 [FusedDense] Allow Row/ColumnParallelLinear to have uneven split 2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751 [FusedDense] Run black on fused_dense.py 2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal 2023-08-16 15:12:36 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training (#446)
* prelim.

* add hf convertion fn.

* mlp.

* change name.

* fix bug.

* inverse permute.

* change comment.

* revert style changes.

* fix.

* add doc.

* revert.

* enable load safe.

* fix safe load.

* fix import.

* fix typing-related lints.

* fix ckpt loading logic.

* make single gpu work.

* test with parallel.

* ckpt format.

* enable pretrained state dict.

* remove unused imports.

* remove unused.

* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
c60851a825 Bump to v2.0.7 2023-08-14 14:55:35 -07:00
Tri Dao
f8dccfc90a [CI] Fix MATRIX_CUDA_VERSION check 2023-08-14 10:27:26 -07:00
Tri Dao
9c531bdc0a Use single thread compilation for cuda12.1, torch2.1 to avoid OOM CI 2023-08-14 10:03:31 -07:00
Tri Dao
67ae6fd74b Bump to v2.0.6 2023-08-13 16:52:48 -07:00
Tri Dao
c5e87b11e9 Bump to v2.0.5 2023-08-13 13:55:04 -07:00
Tri Dao
364a5b4a71 [MLP] Change the check for out_features being None 2023-08-10 00:04:38 -07:00
Tri Dao
d30f2e1cd5 Bump to v2.0.4 2023-08-01 09:01:07 -07:00
Tri Dao
a4e5d1eddd Bump to v2.0.3 2023-07-31 17:49:23 -07:00
Tri Dao
8f4cd4c16b [Docs] Fix docstring about Q nheads being divisible by KV nheads 2023-07-31 17:47:03 -07:00
Tri Dao
184b992dcb [GPT] Implement parallel LLaMa 2023-07-28 15:52:48 -10:00
Tri Dao
840f7925a0 [Docs] Fix mention of MQA/GQA in qkvpacked functions 2023-07-28 12:26:29 -10:00
Tri Dao
60499abcfd [Benchmark] Add script to benchmark FlashAttention 2023-07-28 00:26:52 -10:00
Kirthi Shankar Sivamani
32a953f486
Request for v2.0.2 (#388)
* Bump version to 2.0.2

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Update version in Dockerfile

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-28 02:46:03 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Save seed and offset for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Single thread write to global mem

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1colblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compute_dq_dk_dv_1rowblock get seed and offset from launch params

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change forward c++ APIs to save RNG state for backward

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change backward c++ APIs to set RNG state for bprop launcher

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python side API changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bug fix; only save seeds instead of full offset

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Account for 3D grid size

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
4c98d0b41f [MLP] Edit ParallelGatedMlp 2023-07-26 09:39:37 -10:00
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp (#251) 2023-07-26 12:14:15 -07:00
Tri Dao
b252072409 Bump to v2.0.1 2023-07-23 12:33:42 -10:00
Tri Dao
d38357dd2f [GPT] Implement Falcon 2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert (#363) 2023-07-23 00:21:45 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
Tri Dao
ec9f74ab9a [Rotary] Don't store inv_freq in state_dict 2023-07-22 23:52:42 -07:00
Tri Dao
75e334d407 [MLP] Add ParallelMLP 2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6 [GPT] Enable FlashAttention for GPT-J 2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2 [Block] Re-enable DropPath 2023-07-21 16:39:23 -07:00
Tri Dao
b4cc152e97 Make sure dout is contiguous 2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c Bump to v1.0.9 2023-07-17 03:16:40 -07:00
Volodymyr Kyrylov
70ab266a56 rotary: update cos/sin cache when switching from inference mode
This resolves RuntimeErrors after running evaluation in inference mode:

```
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
    qkv = self.rotary_emb(qkv)
  File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
    return apply_rotary_emb_qkv_(
  File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
Tri Dao
d2f4324f4c [LayerNorm] Make sure memory addresses are aligned to 16 bytes 2023-07-04 14:53:12 -07:00
Tri Dao
e8a0b4acdd [Doc] Change total -> total_q 2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8 Bump to v1.0.8 2023-07-02 17:04:54 -07:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
ljss
8e44c0eefb
Fix a bug 2023-06-02 13:46:19 +08:00
Tri Dao
85b51d61ee Bump version to 1.0.7 2023-05-30 14:18:44 -07:00
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00