Commit Graph

329 Commits

Author SHA1 Message Date
Federico Berto
69f5f7d0a2 [BugFix] cannot unpack non-iterable NoneType object 2023-05-07 03:07:44 +09:00
Federico Berto
3889ba168b [BugFix] cannot unpack non-iterable NoneType object 2023-05-07 03:07:30 +09:00
Tri Dao
a9a4b4e4f2 [LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm 2023-05-04 23:39:43 -07:00
Tri Dao
fcab93b43a [Gen] Minor tweak to allocate_inference_cache 2023-04-21 11:56:47 -07:00
Tri Dao
ba2fe7f378 [Gen] Move allocate_inference_cache to within the model 2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1 [GPT] Add option to only return the logit for the last token 2023-04-20 17:21:08 -07:00
Tri Dao
311d6606bf [Gen] Fix FT kernel smem size, CG when batch size changed 2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f Implement GatedMlp 2023-04-18 03:37:14 -07:00
Tri Dao
ac3b684cdb Have a separate nn.Dropout module in SelfAttention module 2023-04-17 22:34:05 -07:00
Kirthi Shankar Sivamani
a0997bc77c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-14 21:45:37 -07:00
Tri Dao
605655bc66 [Gen] Fix FT kernel when using CG 2023-04-14 16:50:01 -07:00
Kirthi Shankar Sivamani
081c2b012a
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-13 19:36:45 -07:00
Tri Dao
1c9ef9b399 [Gen] Measure prompt processing + decoding time, not just decoding 2023-04-13 15:39:56 -07:00
Tri Dao
6f6e9a9aaf [FusedDense] Enable sqrelu activation in FusedMLP 2023-04-13 15:29:32 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
315fd31f0c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-12 22:42:24 -07:00
Zhiyuan Chen
8c42415664
make mlp hidden_features defaults to 4*in_features 2023-04-13 11:08:21 +08:00
Kirthi Shankar Sivamani
31018c5fa0 Support CUDA graph capture
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
d6fc860573
Merge pull request #147 from ksivaman/add_deterministic_execution_option
Add option for deterministic execution
2023-03-31 17:32:50 -04:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00
Kirthi Shankar Sivamani
b6aa059bbf Add option for deterministic execution 2023-03-30 18:23:35 -07:00
Tri Dao
993d12448e Implement GPT-NeoX 2023-03-29 01:21:25 -07:00
Tri Dao
f5d0fbd468 [FT] Fix FT's single query attention for bf16 hdim128 rotary 2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875 Implement GPT-J 2023-03-22 16:16:58 -07:00
Tri Dao
5d079fdd7a [Triton] Fix benchmark_causal, mention Triton version 2023-03-22 00:51:16 -07:00
Tri Dao
dc08ea1c33 Support H100 for other CUDA extensions 2023-03-15 16:59:27 -07:00
Vik Paruchuri
3165398074 Remove unused kwargs in flashattention 2023-03-15 10:36:19 -07:00
Tri Dao
e45a46a5b7 [Rotary] Implement GPT-J style (interleaved) rotary 2023-03-14 14:35:53 -07:00
Tri Dao
78b7a1dc18 [OPT] Load fp16 weights on CPU before moving to GPU 2023-01-22 17:01:32 -08:00
Tri Dao
eb33e587e9 [LayerNorm] Rename x1 -> residual 2023-01-19 13:07:27 -08:00
Tri Dao
f68d41ec77 [Gen] Add OPT to generation test 2023-01-17 19:59:06 -08:00
Tri Dao
88173a1aaf [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP 2023-01-17 18:12:27 -08:00
Tri Dao
780e8eeabb [ViT] Support timm checkpoint, add tests 2023-01-16 01:20:34 -08:00
Tri Dao
2ec7d3f72c
Merge pull request #105 from jamaliki/patch-1
Change default dropout value in documentation
2023-01-15 23:01:20 -08:00
Tri Dao
ef085cfcda [ViT] Fix extra norm_0, use new LN order in Block 2023-01-15 22:58:56 -08:00
Tri Dao
ff34123bd4 Reorder LN in Block, support OPT 2023-01-15 22:14:31 -08:00
Tri Dao
7c2191542a [Gen] Make generation work with Tensor Parallel 2023-01-15 11:34:27 -08:00
Kiarash Jamali
41cb909741
Change default dropout value in documentation
Documentation says default is 0.1, but the code has attention_dropout default at 0.0
2023-01-13 10:50:07 +00:00
Tri Dao
f95c2fc108 [Gen] Remove commented code 2023-01-07 19:06:39 -08:00
Tri Dao
b48599002a [Gen] Add timing option 2023-01-07 19:05:09 -08:00
Tri Dao
0938298e4c [Gen] Adjust shape of kv_cache when using FT 2023-01-07 17:27:54 -08:00
Tri Dao
e02fd588aa [Gen] Implement top-k and top-p sampling 2023-01-07 17:00:02 -08:00
Tri Dao
11be742aa3 [Gen] Test generation with rotary embedding 2023-01-07 14:37:54 -08:00
Tri Dao
8d9674ed08
Merge pull request #102 from Lamikins/main
fixed cross attention typeerror
2023-01-07 13:56:20 -08:00
Tri Dao
93383bd55b [TP] Implement TensorParallel without sequence parallel 2023-01-07 13:45:22 -08:00
Darius Lam
aec35fd67c fixed cross attention typeerror 2023-01-07 12:58:41 -08:00
Tri Dao
6738d9477d [LayerNorm] Implement RMS Norm 2023-01-06 17:34:22 -08:00
Tri Dao
a668890fcd [Gen] Add option to run generation with FT attention kernel 2023-01-03 22:10:31 -08:00
Tri Dao
4cab4de5ea [TP] Put parallel embeddings in separate modules 2023-01-02 08:47:48 -08:00
Tri Dao
1ec09ebd90 [FusedDense] Limit matrix dims to 2M (instead of 64k) 2023-01-01 17:06:39 -08:00
Tri Dao
714c1b4f0f [Bert] Fix embedding layer norm before embedding dropout 2023-01-01 10:38:05 -08:00
Tri Dao
ef1ba918c6 [GPT] Refactor function to shard state_dict for TensorParallel 2023-01-01 00:09:33 -08:00
Tri Dao
65b4064b2a [FusedDense] Kick off input all_gather before weight dtype conversion 2022-12-31 22:47:34 -08:00
Tri Dao
85b8e3d334 [Docs] Mention that XPos's scale_base is recommended to be 512 2022-12-29 20:25:02 -08:00
Tri Dao
a6ec1782dc Bump to v0.2.6 2022-12-27 22:05:20 -08:00
Tri Dao
63670fd84a Implement generation for GPT 2022-12-27 21:01:50 -08:00
Tri Dao
9d797d8848 Support loading GPT2 weights from Huggingface 2022-12-27 11:22:48 -08:00
Tri Dao
c6ecd40a59 Tweak CrossEntropyLoss to take process_group in init 2022-12-27 10:47:43 -08:00
Tri Dao
b4018a5028 Implement Tensor Parallel for GPT model 2022-12-26 16:22:43 -08:00
Tri Dao
78225c5366 Implement Tensor Parallel for GPT2Embeddings 2022-12-25 14:29:53 -08:00
Tri Dao
a8cfe51551 Implement Tensor Parallel for transformer Block 2022-12-25 14:08:21 -08:00
Tri Dao
1e712ea8b0 Implement TensorParallel for MHA 2022-12-25 11:39:55 -08:00
Tri Dao
226a1b721d Implement TensorParallel for FusedDense and FusedDenseGeluDense 2022-12-24 11:48:56 -08:00
Tri Dao
dff68c2b22 Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss 2022-12-23 14:51:08 -08:00
Tri Dao
e68ebbe89a Simplify FusedDense 2022-12-22 21:25:31 -08:00
Tri Dao
496e4f528c Implement XPos (Sun et al.) 2022-12-21 14:17:58 -08:00
Tri Dao
13cdceb377 Implement last_layer_subset optimization for BERT 2022-12-19 22:18:46 -08:00
Tri Dao
5fb6df0e04 Implement BERT 2022-12-18 21:47:27 -08:00
Alexander Ploshkin
ee8984d2be add asserts for sin shape 2022-12-17 13:34:57 +04:00
Alexander Ploshkin
c7c66976cc fix slicing dimensions 2022-12-16 15:39:06 +04:00
Alexander Ploshkin
96656b9323 Remove redundant shape asserts in rotary embeddings 2022-12-15 18:13:21 +04:00
Tri Dao
6b5f271c6d [Triton] Avoid einops repeat by using Tensor.expand 2022-12-14 14:48:41 -08:00
Tri Dao
88c4e5dbf6 Fix the case when dout is not contiguous 2022-12-13 13:58:17 -08:00
Tri Dao
5db330519a [LayerNorm] Support taking subset of input or subset of output 2022-12-12 22:16:14 -08:00
Tri Dao
ae137ed17a [LayerNorm] Fuse LayerScale 2022-12-10 23:28:23 -08:00
Tri Dao
8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2022-12-09 02:06:22 -08:00
Tri Dao
1feb94265c [ViT] Use dropout_add_ln for the 1st layer norm 2022-11-23 12:48:56 -08:00
Tri Dao
b8ccd20098 [Triton] Fix variable name from qkv to kv (h/t FrankZijlstra) 2022-11-22 02:07:32 -08:00
Tri Dao
054816177e Bump version to 0.2.1 2022-11-20 22:35:59 -08:00
Tri Dao
0fa5c0d7ef Add PatchEmbed 2022-11-17 16:56:06 -08:00
Tri Dao
ece539abd6 Add __init__.py files to subdirectories for installation 2022-11-17 16:55:44 -08:00
Tri Dao
71f674ae23 [Rotary] Customize base, support seqlen_offset 2022-11-17 11:43:36 -08:00
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f Add MLP, MHA, Block, Embedding modules 2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2022-11-13 21:59:20 -08:00
Tri Dao
343492ec30 Make nccl operations async in CrossEntropyLossParallel 2022-11-13 17:27:26 -08:00
Tri Dao
7c9953815a Add fused cross entropy loss 2022-11-12 21:58:41 -08:00
Tri Dao
55797f32c9 Remove RotaryEmbedding from FlashAttention module
To avoid import error if one doesn't have rotary_emb installed
2022-11-10 11:54:36 -08:00
Tri Dao
908a5b2244 Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty) 2022-11-07 08:58:16 -08:00
Tri Dao
7479757191 Fix pipelining bug in Triton bwd with bias_type=matrix 2022-11-06 11:50:35 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
ca81f32e04 Implement rotary embedding in CUDA 2022-11-04 22:42:01 -07:00
Tri Dao
62025e1aff Fix more race condition in Triton bwd when there's bias 2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123 Fix race condition in Triton bwd when there's bias 2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b Implement attention bias for Triton version 2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b Fix race condition for Triton bwd for headdim 48 and 96 2022-11-03 15:52:40 -07:00
Tri Dao
aacc10fbab Fix race condition in Triton bwd for non-po2 headdims 2022-11-02 07:32:54 -07:00
Tri Dao
1fb12afdfb Avoid memcpy in the Triton bwd 2022-11-01 15:06:45 -07:00
Tri Dao
731f154de3 Fix race conditions in the Triton bwd for headdim=64 2022-11-01 15:05:55 -07:00
Tri Dao
9b0bc97872 Fix race condition in Triton fwd 2022-10-31 14:34:57 -07:00
Tri Dao
215930bce3 Fix EVEN_M & EVEN_HEADDIM for headdim=40 in Triton bwd 2022-10-31 01:41:49 -07:00
Tri Dao
4f81aff46e Add debug_barrier for all headdims in Triton bwd 2022-10-31 01:25:02 -07:00
Tri Dao
bedcbd6a71 Disable some autotune configs that give wrong results in Triton bwd 2022-10-31 01:05:51 -07:00
Tri Dao
e78d509c64 [WIP] Support all head dimensions up to 128 in the Triton bwd
WIP because there seems to be some race conditions for head dimensions other
than 16, 32, 64, 128.
2022-10-31 00:46:22 -07:00
Tri Dao
008951f1d9 Support all head dimensions up to 128 in the Triton fwd 2022-10-30 22:10:48 -07:00
Tri Dao
b910bf14c1 Support arbitrary seqlens (both q & k) in Triton bwd 2022-10-30 21:50:53 -07:00
Tri Dao
dc55469355 Support arbitrary seqlen_k in Triton bwd 2022-10-30 21:26:26 -07:00
Tri Dao
d11341fd1a Fix Triton fwd to support seqlen not multiples of 128 2022-10-30 19:05:47 -07:00
Tri Dao
b0c0db81f6 Implement FlashAttention in Triton 2022-10-30 18:09:11 -07:00
Tri Dao
46fd2a20b2 Support all head dims that are multiples of 8, up to 128 2022-10-24 16:04:21 -07:00
Tri Dao
ed553e9238 Add Megatron attention implementation for benchmarking 2022-10-23 23:04:16 -07:00
Tri Dao
50ca23488d Add Triton implementation for benchmarking 2022-10-23 17:25:56 -07:00
Tri Dao
fb88e5e4b3 Move benchmark utils, support AMP 2022-10-23 12:50:00 -07:00
Tri Dao
a5a8806d1a Split bwd on the seqlen_q dimension 2022-10-23 11:35:15 -07:00
Tri Dao
a44f48df5a Split fwd on the seqlen_q dimension 2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6 Rework dropout to decouple forward and backward
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
1b9facacc3 Fix QKV interface to allocate output in Python 2022-10-14 03:33:41 -07:00
Tri Dao
5badfb7848 Implement attention kernel that splits the batch into two 2022-10-13 20:49:02 -07:00
Antoine Adam
4e38df059e
remove numpy dependency
According to the `setup.py` file, only dependencies are torch and einops. But the `bert_padding.py` file requires `numpy` only to multiply the elements of a `torch.Size` object. This change aims at allowing the use of FlashAttention without numpy.
2022-10-06 19:17:15 +02:00
Tri Dao
13403e8115 Relax assert to allow both bf16 and fp16 2022-09-11 12:09:43 -07:00
eric-tc-wong
b410d14f28
Update flash_attention.py
Recasting query and key after rotary_emb()
2022-09-06 17:29:49 -04:00
Tri Dao
19d1261025 Add back need_weights in FlashMHA 2022-08-09 10:14:10 -07:00
Tri Dao
6cc7342575 Support index_first_axis with more than 2 dimensions 2022-08-05 09:48:16 -07:00
Tri Dao
713ea302d7 Allow headdim 128 in FlashMHA interface 2022-08-05 09:47:22 -07:00
Tri Dao
a5559a0e75 Do P * dP (pointwise) in the bwd in fp32 instead of fp16 2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Gustaf
af4a9ce024 Add missing __init__.py 2022-07-03 02:04:55 -04:00
Tri Dao
5a61cb7729 Rename src -> flash_attn 2022-06-01 18:50:26 -07:00