Commit Graph

117 Commits

Author SHA1 Message Date
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
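The commit above folds rotary embedding into the kvcache kernel, so q and the incoming k are rotated at their cache positions inside the same call. A minimal decoding-step sketch, assuming flash-attn at this commit; argument names follow the flash_attn_with_kvcache docstring, and all shapes and values are illustrative:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 16, 64, 512
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_new, v_new = torch.randn_like(q), torch.randn_like(q)
k_cache = torch.zeros(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 100, dtype=torch.int32, device="cuda")

# Precomputed rotary tables: one row per position, headdim // 2 frequencies each.
pos = torch.arange(cache_len, device="cuda", dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, headdim, 2, device="cuda").float() / headdim))
angles = torch.outer(pos, inv_freq)
rotary_cos, rotary_sin = angles.cos().half(), angles.sin().half()

# The kernel rotates q and k_new at positions cache_seqlens, appends k_new/v_new
# into the caches in place, then attends over the valid prefix.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    rotary_cos=rotary_cos, rotary_sin=rotary_sin,
    cache_seqlens=cache_seqlens, causal=True,
)
```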
Tri Dao
5400fdc4ac [CE] Implement CrossEntropyLoss in Triton 2023-09-15 20:05:28 -07:00
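The Triton cross-entropy lives under flash_attn.losses.cross_entropy and is meant as a drop-in for nn.CrossEntropyLoss on flattened (tokens, vocab) logits. A hedged usage sketch; the inplace_backward keyword follows the module as of this commit, so double-check it against your tree:

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss(ignore_index=-100, inplace_backward=True)
logits = torch.randn(8 * 2048, 50257, dtype=torch.float16, device="cuda", requires_grad=True)
labels = torch.randint(0, 50257, (8 * 2048,), device="cuda")
loss = loss_fn(logits, labels)
loss.backward()  # inplace_backward reuses the logits buffer for the gradient to save memory
```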
Tri Dao
56b7fc6ee0 Simplify the implementation of KVcache attn by appending KV first 2023-09-13 15:55:48 -07:00
Tri Dao
d0032700d1 Add tests for Pythia, GPT-JT, and RedPajama models 2023-09-13 01:10:39 -07:00
Kevin Hu
07005806ff Add BigCode converters (#532) 2023-09-10 17:24:50 -07:00
Tri Dao
8a733cbd53 [Gen] Fix calling update_graph_cache in tests 2023-09-10 17:22:37 -07:00
Kevin Hu
4c91621a5e Inverse state dict for BERT (#527) 2023-09-09 01:44:21 -07:00
Tri Dao
a86442f0f3 [Gen] Use flash_attn_with_kvcache in generation 2023-09-07 08:24:43 -07:00
Tri Dao
9795159082 [Rotary] Set device before launching Triton kernel to avoid error 2023-09-05 21:29:03 -07:00
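The bug class being fixed here is generic to Triton extensions: a kernel launches on the current CUDA device, not on the device of its arguments, so a tensor living on another GPU can trigger an illegal-memory error. A toy reproduction of the fix pattern (the copy kernel itself is illustrative):

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.arange(8, device="cuda", dtype=torch.float32)
y = torch.empty_like(x)
# The fix: make the tensors' device current before launching.
with torch.cuda.device(x.device.index):
    copy_kernel[(1,)](x, y, x.numel(), BLOCK=8)
```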
Tri Dao
fd20f16a4e Support cache_seqlens being integer 2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5 [Gen] Refactor decoding function 2023-09-04 17:01:38 -07:00
Tri Dao
37c6e05406 Implement flash_attn_with_kvcache 2023-09-04 00:11:44 -07:00
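For readers new to the API introduced here: flash_attn_with_kvcache fuses the append-to-cache and attention steps of decoding. A pure-PyTorch sketch of its reference semantics for one-token queries (what the kernel computes, not how it computes it):

```python
import math
import torch

def decode_step_ref(q, k_cache, v_cache, k_new, v_new, cache_seqlens):
    """q: (b, 1, h, d); caches: (b, s_max, h, d); cache_seqlens: (b,) ints."""
    outs = []
    for b in range(q.shape[0]):
        s = int(cache_seqlens[b])
        k_cache[b, s] = k_new[b, 0]  # append in place, as the kernel does
        v_cache[b, s] = v_new[b, 0]
        k, v = k_cache[b, : s + 1], v_cache[b, : s + 1]  # valid prefix only
        scores = torch.einsum("hd,shd->hs", q[b, 0].float(), k.float())
        probs = (scores / math.sqrt(q.shape[-1])).softmax(dim=-1)
        outs.append(torch.einsum("hs,shd->hd", probs, v.float()))
    return torch.stack(outs)[:, None].to(q.dtype)  # (b, 1, h, d)
```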
Tri Dao
0c04943fa2 Require CUDA 11.6+, clean up setup.py 2023-09-03 21:24:56 -07:00
Tri Dao
798858f9f1 Fix test_baichuan 2023-09-03 21:01:37 -07:00
Tri Dao
b28ec236df [Rotary] Implement varlen rotary 2023-09-03 17:57:10 -07:00
Tri Dao
1c523c1ce1 [Rotary] Speed up rotary kernel when interleaved=True 2023-09-03 16:24:37 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
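The Triton kernel backs flash_attn.layers.rotary.apply_rotary_emb. Its non-interleaved ("GPT-NeoX") variant matches this plain-PyTorch reference, a sketch:

```python
import torch

def apply_rotary_ref(x, cos, sin):
    """x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2)."""
    d = cos.shape[-1]
    x1, x2 = x[..., :d], x[..., d : 2 * d]
    cos = cos[None, :, None, :]  # broadcast over batch and heads
    sin = sin[None, :, None, :]
    return torch.cat(
        [x1 * cos - x2 * sin, x1 * sin + x2 * cos, x[..., 2 * d :]], dim=-1
    )  # dims beyond rotary_dim pass through unchanged
```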
dan_the_3rd
011ec323d6 Support MQA + MP for decoding (#490) 2023-08-30 10:29:54 -07:00
Co-authored-by: danthe3rd <danthe3rd>
Tri Dao
b1fbbd8337 Implement splitKV attention 2023-08-29 00:58:29 -07:00
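Split-KV parallelizes over the key/value sequence: each split attends to its slice and returns a partial output plus its log-sum-exp, and the partials combine exactly. The merge rule, as a sketch:

```python
import torch

def combine_splits(outs, lses):
    """outs: (nsplits, ..., d) per-split outputs; lses: (nsplits, ...) their log-sum-exps."""
    lse_total = torch.logsumexp(lses, dim=0)  # normalizer over all splits
    weights = torch.exp(lses - lse_total)     # each split's share of the softmax mass
    return (weights.unsqueeze(-1) * outs).sum(dim=0)
```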
Tri Dao
9f42cb6e7a [Gen] Clone logits before returning when cg=True 2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Tri Dao
371e20658c [GPT] Test generation when passing in multiple tokens 2023-08-26 13:56:41 -07:00
Tri Dao
c000c3a2c0 [GPT] Move more tests to test_gpt.py 2023-08-26 13:00:40 -07:00
Tri Dao
9b713872ea [GPT] Move GPT and OPT generation tests to test_{gpt,opt}.py 2023-08-26 12:55:02 -07:00
Tri Dao
73bd3f3bbb Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 2023-08-25 15:05:28 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
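Under bottom-right alignment, when seqlen_q != seqlen_k the mask is anchored so the last query sees the whole key sequence: query i attends to key j iff j <= i + seqlen_k - seqlen_q, which is what a decoding step on top of a prefix expects. A toy illustration:

```python
import torch

def causal_mask(seqlen_q, seqlen_k, align="bottom_right"):
    """True = may attend. Illustrative reconstruction of the two alignments."""
    i = torch.arange(seqlen_q)[:, None]
    j = torch.arange(seqlen_k)[None, :]
    if align == "top_left":                # old behavior: row i sees keys 0..i
        return j <= i
    return j <= i + (seqlen_k - seqlen_q)  # new behavior: last row sees all keys

print(causal_mask(2, 5).int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```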
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57 FEAT: add support for baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00
Tri Dao
0e8c46ae08 Run isort and black on test files 2023-08-18 20:59:35 -07:00
Xuechen Li
7fcd3e6a04 map custom model state_dict back to huggingface format (#465) 2023-08-18 20:51:39 -07:00
* fix name.
* set inv function.
* add map back function.
* handle gqa.
* add type annotation to avoid confusion.
* fix docstr.
* test inverse remap logic.
Xuechen Li
bb4cded17b support when num_heads is not divisible by world_size; resolves #459 (#461) 2023-08-18 14:10:35 -07:00
* unequal rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reusable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
Tri Dao
a81900d4c1 [ViT] Minor fix so it runs 2023-08-17 17:25:34 -07:00
Tri Dao
c65b5106ac Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal 2023-08-16 15:12:36 -07:00
Xuechen Li
0f7853c6a1 enable loading hf llama checkpoints for training (#446) 2023-08-15 08:33:15 -07:00
* prelim.
* add hf conversion fn.
* mlp.
* change name.
* fix bug.
* inverse permute.
* change comment.
* revert style changes.
* fix.
* add doc.
* revert.
* enable load safe.
* fix safe load.
* fix import.
* fix typing-related lints.
* fix ckpt loading logic.
* make single gpu work.
* test with parallel.
* ckpt format.
* enable pretrained state dict.
* remove unused imports.
* remove unused.
* mark idea related.
Tri Dao
3524e13c11 Update to Cutlass 3.1 2023-08-13 13:53:17 -07:00
Tri Dao
1c41d2b0e5 Fix race condition in bwd (overwriting sK) 2023-08-01 09:00:10 -07:00
Tri Dao
a4f148b6ab Fix masking of bwd when seqlen is not divisible by 128 2023-07-31 17:46:34 -07:00
Tri Dao
184b992dcb [GPT] Implement parallel LLaMa 2023-07-28 15:52:48 -10:00
Haodong Lyu
8ee62efca3 Implement ParallelGatedMlp (#251) 2023-07-26 12:14:15 -07:00
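For context, the module being parallelized is the gated MLP (the SwiGLU pattern used by Llama and friends). A single-GPU sketch of the computation, with the fused gate/up projection as an assumption about the layout:

```python
import torch.nn as nn
import torch.nn.functional as F

class GatedMlpSketch(nn.Module):
    """out = fc2(silu(gate) * up), with gate and up fused into one matmul."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(dim, 2 * hidden_dim)  # produces gate and up together
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        gate, up = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(F.silu(gate) * up)
```

The parallel version shards fc1 column-wise and fc2 row-wise across tensor-parallel ranks, Megatron-style, so a single all-reduce suffices per MLP.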
Tri Dao
56ccaff126 [GPT] Add LLaMa-13B to test 2023-07-26 07:22:22 -10:00
Tri Dao
8e9820a55b [Rotary] Fix tests when loading state dict with rotary inv_freqs 2023-07-26 07:16:33 -10:00
Tri Dao
2a2a3c4bfd [LayerNorm] Add test for randomness 2023-07-23 12:31:55 -10:00
Tri Dao
d38357dd2f [GPT] Implement Falcon 2023-07-23 10:32:29 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
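MQA/GQA keep fewer key/value heads than query heads, with each KV head shared by a group of query heads. The kernel indexes the smaller KV tensors directly, but the semantics match this expansion sketch:

```python
import torch

def expand_kv(kv: torch.Tensor, nheads_q: int) -> torch.Tensor:
    """kv: (batch, seqlen, nheads_k, d) with nheads_k dividing nheads_q.
    Repeats each KV head for its group of query heads (MQA is nheads_k == 1)."""
    nheads_k = kv.shape[2]
    assert nheads_q % nheads_k == 0
    return torch.repeat_interleave(kv, nheads_q // nheads_k, dim=2)
```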
Tri Dao
b3177dfaf6 [GPT] Enable FlashAttention for GPT-J 2023-07-21 17:29:10 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
d2f4324f4c [LayerNorm] Make sure memory addresses are aligned to 16 bytes 2023-07-04 14:53:12 -07:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
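The reason for the fix, sketched: the inverse-frequency table spans several orders of magnitude, and fp16/bf16 have too few mantissa bits to compute it accurately, so it is built in fp32 regardless of the model dtype:

```python
import torch

dim, base = 128, 10000.0
# inv_freq[i] = base ** (-2i / dim), computed in fp32 even for fp16/bf16 models
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
```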
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00
Tri Dao
a9a4b4e4f2 [LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm 2023-05-04 23:39:43 -07:00
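For reference, the norm this commit swaps in: RMSNorm rescales by the root-mean-square with no mean subtraction and no bias, unlike LayerNorm. A minimal sketch:

```python
import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # 1 / sqrt(mean(x^2) + eps), computed in fp32 for stability
        rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return (x.float() * rms).to(x.dtype) * self.weight
```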