Commit Graph

68 Commits

Author SHA1 Message Date
Markus Krimmel
6bbc532388 fix: cast the alibi slopes to torch.float32 (#846) 2024-03-15 00:49:40 -07:00
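The commit above casts the ALiBi slopes to torch.float32 before they reach the kernel. For reference, a minimal sketch of the standard ALiBi slope computation; the helper name is hypothetical and the power-of-two restriction plus the float32 cast at the end are assumptions inferred from the commit message:

```python
import torch

def get_alibi_slopes(nheads: int) -> torch.Tensor:
    # Standard ALiBi slopes 2^(-8*(i+1)/nheads); this simple form assumes
    # nheads is a power of two (reference implementations interpolate otherwise).
    ratio = 2.0 ** (-8.0 / nheads)
    slopes = [ratio ** (i + 1) for i in range(nheads)]
    # Keep the slopes in float32 regardless of the model dtype, per the fix above.
    return torch.tensor(slopes, dtype=torch.float32)
```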
Tri Dao
a190df011c Add window_size option to ParallelMHA 2024-02-10 01:02:14 -08:00
Tri Dao
ef0ed10622 Add window_size option to MHA and GPT 2024-01-31 02:42:23 -08:00
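The two commits above add a window_size (sliding-window attention) option to MHA, GPT, and ParallelMHA. A minimal sketch of how such an option might be used; everything beyond embed_dim and num_heads (the use_flash_attn flag, the window_size name and its (left, right) convention) is an assumption based on the commit messages, not a confirmed signature:

```python
import torch
from flash_attn.modules.mha import MHA

# window_size=(left, right): each query attends to at most `left` tokens behind it
# and `right` tokens ahead of it; (-1, -1) would mean no windowing (assumed convention).
mha = MHA(
    embed_dim=1024,
    num_heads=16,
    causal=True,
    use_flash_attn=True,   # assumed flag selecting the FlashAttention path
    window_size=(256, 0),  # assumed keyword, per the commit message
).to(device="cuda", dtype=torch.float16)

x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
out = mha(x)  # same shape as x
```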
Tri Dao
abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 2024-01-05 00:31:17 -08:00
jiaxingli
386e391117 Fix: implement deterministic backward in mha (#748) 2024-01-02 18:13:56 -08:00
* fix deterministic
* fix deterministic
Tri Dao
3f7d5786ba Pass alibi slopes to flash_attn_with_kvcache during generation 2023-12-24 20:31:59 -08:00
Tri Dao
c3b2196652 Add Alibi to MHA, test with Baichuan-13B 2023-12-21 22:49:55 -08:00
Tri Dao
0705d2718d [Llama] Fix some tests, add tests for Llama 2 and CodeLlama 2023-09-20 23:36:46 -07:00
Tri Dao
e6a8026489 [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset 2023-09-19 22:20:22 -07:00
Tri Dao
dfe29f5e2b [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead 2023-09-18 15:29:06 -07:00
Tri Dao
8a733cbd53 [Gen] Fix calling update_graph_cache in tests 2023-09-10 17:22:37 -07:00
Tri Dao
a86442f0f3 [Gen] Use flash_attn_with_kvcache in generation 2023-09-07 08:24:43 -07:00
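The commit above moves generation onto flash_attn_with_kvcache, which writes the new keys/values into a pre-allocated cache and attends over it in one kernel call. A minimal single-token decoding sketch; the (batch, seqlen, nheads, headdim) layout and the exact keyword names are assumptions, not a definitive usage:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 16, 64, 1024
dev, dtype = "cuda", torch.float16

# Pre-allocated KV cache; the kernel appends the new k/v into it in place.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=dev, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
# Number of tokens already stored in the cache for each sequence.
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device=dev)

# One decoding step: a single query token plus its new key/value.
q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v,
                              cache_seqlens=cache_seqlens, causal=True)
```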
Tri Dao
3557e0bb8f [MLP] Implement SwiGLU with torch jiterator 2023-09-04 15:43:53 -07:00
Tri Dao
798858f9f1 Fix test_baichuan 2023-09-03 21:01:37 -07:00
Tri Dao
de2949f37d [Rotary] Pass max_seqlen from mha.py to rotary during inference 2023-09-03 11:37:06 -07:00
dan_the_3rd
011ec323d6 Support MQA + MP for decoding (#490) 2023-08-30 10:29:54 -07:00
Co-authored-by: danthe3rd <danthe3rd>
Tri Dao
a2974e850a Change causal for CrossAttention in mha.py to align to bottom right 2023-08-26 12:57:33 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b support when num_heads is not divisible by world_size; resolves #459 (#461) 2023-08-18 14:10:35 -07:00
* uneql rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reuseable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
Tri Dao
bec5b3d374 [MHA] Run black on mha.py 2023-08-16 23:47:13 -07:00
Tri Dao
364a5b4a71 [MLP] Change the check for out_features being None 2023-08-10 00:04:38 -07:00
Tri Dao
4c98d0b41f [MLP] Edit ParallelGatedMlp 2023-07-26 09:39:37 -10:00
Haodong Lyu
8ee62efca3 Implement ParallelGatedMlp (#251) 2023-07-26 12:14:15 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
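The commit above adds MQA/GQA to the MHA module, i.e. fewer key/value heads than query heads. At the kernel level this reduces to calling the attention function with a smaller KV head count; a minimal sketch with illustrative head counts (the module-level wiring of this option is not shown here):

```python
import torch
from flash_attn import flash_attn_func

# Grouped-query attention: 16 query heads share 2 KV heads
# (the number of query heads must be divisible by the number of KV heads).
q = torch.randn(2, 512, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 512, 2, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 512, 2, 64, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)  # shape (2, 512, 16, 64)
```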
Tri Dao
75e334d407 [MLP] Add ParallelMLP 2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6 [GPT] Enable FlashAttention for GPT-J 2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2 [Block] Re-enable DropPath 2023-07-21 16:39:23 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
62e9814466 [Rotary] Make sure frequency calculation is in fp32 2023-07-02 16:39:39 -07:00
ljss
8e44c0eefb Fix a bug 2023-06-02 13:46:19 +08:00
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00
Federico Berto
3889ba168b [BugFix] cannot unpack non-iterable NoneType object 2023-05-07 03:07:30 +09:00
Tri Dao
ba2fe7f378 [Gen] Move allocate_inference_cache to within the model 2023-04-20 18:15:12 -07:00
Tri Dao
311d6606bf [Gen] Fix FT kernel smem size, CG when batch size changed 2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f Implement GatedMlp 2023-04-18 03:37:14 -07:00
Tri Dao
ac3b684cdb Have a separate nn.Dropout module in SelfAttention module 2023-04-17 22:34:05 -07:00
Tri Dao
605655bc66 [Gen] Fix FT kernel when using CG 2023-04-14 16:50:01 -07:00
Zhiyuan Chen
8c42415664 make mlp hidden_features defaults to 4*in_features 2023-04-13 11:08:21 +08:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00
Tri Dao
f5d0fbd468 [FT] Fix FT's single query attention for bf16 hdim128 rotary 2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875 Implement GPT-J 2023-03-22 16:16:58 -07:00
Tri Dao
88173a1aaf [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP 2023-01-17 18:12:27 -08:00
Tri Dao
780e8eeabb [ViT] Support timm checkpoint, add tests 2023-01-16 01:20:34 -08:00
Tri Dao
ef085cfcda [ViT] Fix extra norm_0, use new LN order in Block 2023-01-15 22:58:56 -08:00
Tri Dao
ff34123bd4 Reorder LN in Block, support OPT 2023-01-15 22:14:31 -08:00
Tri Dao
7c2191542a [Gen] Make generation work with Tensor Parallel 2023-01-15 11:34:27 -08:00
Tri Dao
0938298e4c [Gen] Adjust shape of kv_cache when using FT 2023-01-07 17:27:54 -08:00
Tri Dao
11be742aa3 [Gen] Test generation with rotary embedding 2023-01-07 14:37:54 -08:00
Tri Dao
8d9674ed08 Merge pull request #102 from Lamikins/main 2023-01-07 13:56:20 -08:00
fixed cross attention typeerror