Author | Commit | Message | Date
Markus Krimmel | 6bbc532388 | fix: cast the alibi slopes to torch.float32 (#846) | 2024-03-15 00:49:40 -07:00
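The fix above is about the dtype of the ALiBi slopes handed to the attention kernel: they should be float32 even when the rest of the model runs in fp16/bf16. A minimal sketch of the standard slope computation, assuming a power-of-two head count (the helper name is illustrative):

```python
import torch

def alibi_slopes(nheads: int) -> torch.Tensor:
    # Standard ALiBi slopes: the geometric sequence 2^(-8/n), 2^(-16/n), ...
    # (valid as written for power-of-two nheads).
    start = 2 ** (-8.0 / nheads)
    slopes = [start ** (i + 1) for i in range(nheads)]
    # The point of the commit: keep the slopes in float32, independent of the qkv dtype.
    return torch.tensor(slopes, dtype=torch.float32)
```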
Tri Dao | a190df011c | Add window_size option to ParallelMHA | 2024-02-10 01:02:14 -08:00
Tri Dao | ef0ed10622 | Add window_size option to MHA and GPT | 2024-01-31 02:42:23 -08:00
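The two window_size commits above expose sliding-window (local) attention. As a rough reference for what a (left, right) window means, here is a plain-PyTorch mask for the equal-length case, where a negative bound is treated as unbounded; the fused kernel never materializes such a mask, and its exact alignment rules for unequal query/key lengths are not reproduced here:

```python
import torch

def sliding_window_mask(seqlen: int, left: int, right: int) -> torch.Tensor:
    # True = attend. Query i may look at keys j with i - left <= j <= i + right;
    # a negative bound means "no limit on that side".
    q_idx = torch.arange(seqlen).unsqueeze(-1)
    k_idx = torch.arange(seqlen).unsqueeze(0)
    mask = torch.ones(seqlen, seqlen, dtype=torch.bool)
    if left >= 0:
        mask &= k_idx >= q_idx - left
    if right >= 0:
        mask &= k_idx <= q_idx + right
    return mask
```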
Tri Dao | abbc131173 | [LayerNorm] Switch from CUDA to Triton implementation | 2024-01-05 00:31:17 -08:00
jiaxingli | 386e391117 | Fix: implement deterministic backward in mha (#748) | 2024-01-02 18:13:56 -08:00
    * fix deterministic
    * fix deterministic
Tri Dao | 3f7d5786ba | Pass alibi slopes to flash_attn_with_kvcache during generation | 2023-12-24 20:31:59 -08:00
Tri Dao | c3b2196652 | Add Alibi to MHA, test with Baichuan-13B | 2023-12-21 22:49:55 -08:00
Tri Dao | 0705d2718d | [Llama] Fix some tests, add tests for Llama 2 and CodeLlama | 2023-09-20 23:36:46 -07:00
Tri Dao | e6a8026489 | [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset | 2023-09-19 22:20:22 -07:00
Tri Dao | dfe29f5e2b | [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead | 2023-09-18 15:29:06 -07:00
Tri Dao | 8a733cbd53 | [Gen] Fix calling update_graph_cache in tests | 2023-09-10 17:22:37 -07:00
Tri Dao | a86442f0f3 | [Gen] Use flash_attn_with_kvcache in generation | 2023-09-07 08:24:43 -07:00
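The generation commits above move decoding off the FasterTransformer kernel and onto flash_attn_with_kvcache, which appends the new K/V into a preallocated cache and attends over the cached prefix in one call. A hedged sketch of a single decode step; shapes are illustrative, and keyword arguments beyond q/k_cache/v_cache/k/v/cache_seqlens/causal vary between flash-attn versions:

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 64, 1024
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")  # one new token
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)
k_cache = torch.zeros(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device="cuda")  # tokens already cached

# Writes k_new/v_new into the cache at position cache_seqlens, then attends over the prefix.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True
)
```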
Tri Dao | 3557e0bb8f | [MLP] Implement SwiGLU with torch jiterator | 2023-09-04 15:43:53 -07:00
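The SwiGLU commit fuses the elementwise half of the gated MLP with torch's jiterator; the eager-mode equivalent of that elementwise kernel is simply silu(gate) * up:

```python
import torch
import torch.nn.functional as F

def swiglu_ref(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Eager reference for the elementwise part of SwiGLU; the commit compiles
    # this same expression with torch.cuda.jiterator rather than running it as-is.
    return F.silu(gate) * up
```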
Tri Dao | 798858f9f1 | Fix test_baichuan | 2023-09-03 21:01:37 -07:00
Tri Dao | de2949f37d | [Rotary] Pass max_seqlen from mha.py to rotary during inference | 2023-09-03 11:37:06 -07:00
dan_the_3rd | 011ec323d6 | Support MQA + MP for decoding (#490) | 2023-08-30 10:29:54 -07:00
    Co-authored-by: danthe3rd <danthe3rd>
Tri Dao | a2974e850a | Change causal for CrossAttention in mha.py to align to bottom right | 2023-08-26 12:57:33 -07:00
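"Align to bottom right" refers to where the causal mask sits when query and key lengths differ: the last query row sees the whole key sequence, i.e. query i may attend to key j iff j <= i + (seqlen_k - seqlen_q). A reference mask under that convention:

```python
import torch

def causal_mask_bottom_right(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # True = attend; with seqlen_q == seqlen_k this reduces to the usual lower-triangular mask.
    q_idx = torch.arange(seqlen_q).unsqueeze(-1)
    k_idx = torch.arange(seqlen_k).unsqueeze(0)
    return k_idx <= q_idx + (seqlen_k - seqlen_q)
```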
Tri Dao | f1a73d0740 | Run isort and black on python files | 2023-08-18 14:22:11 -07:00
Xuechen Li | bb4cded17b | support when num_heads is not divisible by world_size; resolves #459 (#461) | 2023-08-18 14:10:35 -07:00
    * uneql rank.
    * trim.
    * enable passing in number of heads for each rank.
    * simplify.
    * simplify.
    * cleanup.
    * fix col parallel.
    * fix bug with row parallel.
    * fit out proj.
    * refac.
    * fix sharding logic.
    * refac sharding.
    * refac.
    * support multiple of.
    * make fn reuseable.
    * fix bug in dimensions.
    * scaffold.
    * test uneven heads.
    * fix test by adding barrier.
    * refac.
    * reuse code.
    * clean up.
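For the uneven-heads support above, the essential arithmetic is deciding how many heads each tensor-parallel rank owns when num_heads % world_size != 0. One common scheme gives the first remainder ranks one extra head; this is illustrative only, and the repo's actual sharding helper may differ in detail:

```python
def heads_on_rank(num_heads: int, rank: int, world_size: int) -> int:
    # e.g. 14 heads on 4 ranks -> [4, 4, 3, 3]
    base, rem = divmod(num_heads, world_size)
    return base + (1 if rank < rem else 0)

assert [heads_on_rank(14, r, 4) for r in range(4)] == [4, 4, 3, 3]
```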
Tri Dao | bec5b3d374 | [MHA] Run black on mha.py | 2023-08-16 23:47:13 -07:00
Tri Dao | 364a5b4a71 | [MLP] Change the check for out_features being None | 2023-08-10 00:04:38 -07:00
Tri Dao | 4c98d0b41f | [MLP] Edit ParallelGatedMlp | 2023-07-26 09:39:37 -10:00
Haodong Lyu | 8ee62efca3 | Implement ParallelGatedMlp (#251) | 2023-07-26 12:14:15 -07:00
Tri Dao | 425dbcb6c6 | [MHA] Implement MQA/GQA | 2023-07-23 00:06:58 -07:00
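MQA/GQA let K and V carry fewer heads than Q. Semantically this is equivalent to repeating each KV head before ordinary attention; the fused kernel avoids materializing the repeat, but a reference expansion looks like this:

```python
import torch

def expand_kv_heads(kv: torch.Tensor, nheads_q: int) -> torch.Tensor:
    # kv: (batch, seqlen, nheads_kv, headdim) with nheads_kv dividing nheads_q.
    # Each KV head is shared by nheads_q // nheads_kv query heads.
    nheads_kv = kv.shape[2]
    assert nheads_q % nheads_kv == 0
    return kv.repeat_interleave(nheads_q // nheads_kv, dim=2)
```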
Tri Dao | 75e334d407 | [MLP] Add ParallelMLP | 2023-07-22 23:45:51 -07:00
Tri Dao | b3177dfaf6 | [GPT] Enable FlashAttention for GPT-J | 2023-07-21 17:29:10 -07:00
Tri Dao | 6fc1e07da2 | [Block] Re-enable DropPath | 2023-07-21 16:39:23 -07:00
Tri Dao | 4f285b3547 | FlashAttention-2 release | 2023-07-17 06:21:34 -07:00
Tri Dao | 62e9814466 | [Rotary] Make sure frequency calculation is in fp32 | 2023-07-02 16:39:39 -07:00
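The rotary fix above is a precision issue: the inverse frequencies involve base ** (2i / dim), which loses accuracy if evaluated in half precision. The standard computation, kept in float32:

```python
import torch

def rotary_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies; the arange/exponent math stays in float32
    # regardless of the model's compute dtype.
    exponents = torch.arange(0, dim, 2, dtype=torch.float32) / dim
    return 1.0 / (base ** exponents)
```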
ljss | 8e44c0eefb | Fix a bug | 2023-06-02 13:46:19 +08:00
Tri Dao | 48bc6eacd6 | [Gen] Add rotary base as an argument to FT attention kernel | 2023-05-30 13:38:34 -07:00
Federico Berto | 3889ba168b | [BugFix] cannot unpack non-iterable NoneType object | 2023-05-07 03:07:30 +09:00
Tri Dao | ba2fe7f378 | [Gen] Move allocate_inference_cache to within the model | 2023-04-20 18:15:12 -07:00
Tri Dao | 311d6606bf | [Gen] Fix FT kernel smem size, CG when batch size changed | 2023-04-20 17:03:13 -07:00
Tri Dao | 96d10f6545 | Implement LLaMa | 2023-04-18 21:51:35 -07:00
Tri Dao | b630aef53f | Implement GatedMlp | 2023-04-18 03:37:14 -07:00
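GatedMlp wires the gated-activation pattern into an MLP module: one input projection produces a gate half and an up half, the activated gate multiplies the up half, and a second projection maps back. A minimal sketch whose names and defaults are illustrative rather than the repo's exact class; with SiLU as the activation this is the SwiGLU variant shown earlier:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMlpSketch(nn.Module):
    def __init__(self, in_features: int, hidden_features: int, activation=F.silu):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 2 * hidden_features)  # gate and up halves, fused
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(self.activation(gate) * up)
```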
Tri Dao | ac3b684cdb | Have a separate nn.Dropout module in SelfAttention module | 2023-04-17 22:34:05 -07:00
Tri Dao | 605655bc66 | [Gen] Fix FT kernel when using CG | 2023-04-14 16:50:01 -07:00
Zhiyuan Chen | 8c42415664 | make mlp hidden_features defaults to 4*in_features | 2023-04-13 11:08:21 +08:00
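The default above follows the usual Transformer convention that the MLP hidden width is four times the model width when not specified; in sketch form:

```python
def mlp_hidden_features(in_features: int, hidden_features=None) -> int:
    # Conventional 4x expansion when the caller does not pass an explicit width.
    return hidden_features if hidden_features is not None else 4 * in_features
```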
Tri Dao | 393882bc08 | [LayerNorm] Implement LN with parallel residual, support dim 8k | 2023-03-31 14:23:45 -07:00
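"LN with parallel residual" targets the GPT-J / GPT-NeoX style block in which attention and MLP branch off the same input and both results are added back to the residual stream. In reference form (the fused kernel combines the normalizations and residual adds; this only states the semantics, and the two-norm variant is an assumption):

```python
def parallel_residual_block(x, ln1, ln2, attn, mlp):
    # Attention and MLP read (separately normalized) copies of the same input;
    # both outputs are added to the residual stream.
    return x + attn(ln1(x)) + mlp(ln2(x))
```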
Tri Dao | f5d0fbd468 | [FT] Fix FT's single query attention for bf16 hdim128 rotary | 2023-03-28 21:27:00 -07:00
Tri Dao | 4d87e4d875 | Implement GPT-J | 2023-03-22 16:16:58 -07:00
Tri Dao | 88173a1aaf | [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP | 2023-01-17 18:12:27 -08:00
Tri Dao | 780e8eeabb | [ViT] Support timm checkpoint, add tests | 2023-01-16 01:20:34 -08:00
Tri Dao | ef085cfcda | [ViT] Fix extra norm_0, use new LN order in Block | 2023-01-15 22:58:56 -08:00
Tri Dao | ff34123bd4 | Reorder LN in Block, support OPT | 2023-01-15 22:14:31 -08:00
Tri Dao | 7c2191542a | [Gen] Make generation work with Tensor Parallel | 2023-01-15 11:34:27 -08:00
Tri Dao | 0938298e4c | [Gen] Adjust shape of kv_cache when using FT | 2023-01-07 17:27:54 -08:00
Tri Dao | 11be742aa3 | [Gen] Test generation with rotary embedding | 2023-01-07 14:37:54 -08:00
Tri Dao | 8d9674ed08 | Merge pull request #102 from Lamikins/main | 2023-01-07 13:56:20 -08:00
    fixed cross attention typeerror