Commit Graph

461 Commits

Author SHA1 Message Date
Tri Dao
ac3b684cdb Have a separate nn.Dropout module in SelfAttention module 2023-04-17 22:34:05 -07:00
Tri Dao
df1344f866 Bump to v1.0.2 2023-04-15 22:19:31 -07:00
Tri Dao
635f159ee3 Merge pull request #166 from ksivaman/enable_cuda_graph_capture (Enable CUDA graph capture) 2023-04-16 00:27:33 -04:00
Kirthi Shankar Sivamani
45567a25a2 only 1 thread writes to global mem in fprop 2023-04-15 06:09:41 +00:00
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Kirthi Shankar Sivamani
a0997bc77c Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-14 21:45:37 -07:00
Tri Dao
221a39fd3a [Docs] Link to Forbes article 2023-04-14 21:20:38 -07:00
Tri Dao
605655bc66 [Gen] Fix FT kernel when using CG 2023-04-14 16:50:01 -07:00
Tri Dao
dceb2687c5 Merge pull request #170 from CrustaceanJ/dependencies (Missing module in `setup.py`) 2023-04-14 15:41:46 -04:00
Pavel Shvets
72629ac9ba add missed module 2023-04-14 20:08:24 +03:00
Kirthi Shankar Sivamani
081c2b012a Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-13 19:36:45 -07:00
Tri Dao
1c9ef9b399 [Gen] Measure prompt processing + decoding time, not just decoding 2023-04-13 15:39:56 -07:00
Tri Dao
6f6e9a9aaf [FusedDense] Enable sqrelu activation in FusedMLP 2023-04-13 15:29:32 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward 2023-04-13 06:25:52 +00:00
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Kirthi Shankar Sivamani
315fd31f0c Merge branch 'HazyResearch:main' into enable_cuda_graph_capture 2023-04-12 22:42:24 -07:00
Tri Dao
5cee071431 Merge pull request #164 from ZhiyuanChen/patch-1 (make mlp hidden_features defaults to 4*in_features) 2023-04-12 23:21:12 -04:00
Zhiyuan Chen
8c42415664 make mlp hidden_features defaults to 4*in_features 2023-04-13 11:08:21 +08:00
Kirthi Shankar Sivamani
31018c5fa0 Support CUDA graph capture 2023-04-12 16:53:22 -07:00
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Tri Dao
853ff72963 Bump version to v1.0.1, fix Cutlass version 2023-04-12 10:05:01 -07:00
Tri Dao
74af023316 Bump version to 1.0.0 2023-04-11 23:32:35 -07:00
Tri Dao
dec4f2e910 [FusedDense] Set workspace size to 32M for Hopper and 4M for others 2023-04-06 23:40:15 -07:00
Tri Dao
d478eeec8f Merge pull request #154 from kuizhiqing/usage (add paddlepaddle in usage) 2023-04-04 02:54:37 -04:00
kuizhiqing
c5be8d3aab add paddlepaddle in usage 2023-04-04 14:15:51 +08:00
Tri Dao
d6fc860573 Merge pull request #147 from ksivaman/add_deterministic_execution_option (Add option for deterministic execution) 2023-03-31 17:32:50 -04:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00
Kirthi Shankar Sivamani
b6aa059bbf Add option for deterministic execution 2023-03-30 18:23:35 -07:00
Tri Dao
009a3e71ec [Training] Fix lightning _PATH import 2023-03-29 01:43:39 -07:00
Tri Dao
993d12448e Implement GPT-NeoX 2023-03-29 01:21:25 -07:00
Tri Dao
f5d0fbd468 [FT] Fix FT's single query attention for bf16 hdim128 rotary 2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875 Implement GPT-J 2023-03-22 16:16:58 -07:00
Tri Dao
4360cfc6a8 [Triton] Fix benchmark_causal.py 2023-03-22 01:34:38 -07:00
Tri Dao
5d079fdd7a [Triton] Fix benchmark_causal, mention Triton version 2023-03-22 00:51:16 -07:00
Tri Dao
dc08ea1c33 Support H100 for other CUDA extensions 2023-03-15 16:59:27 -07:00
Tri Dao
1b18f1b7a1 Support H100 2023-03-15 14:59:02 -07:00
Tri Dao
318e2f1b9b Merge pull request #140 from VikParuchuri/main (Remove unused kwargs like device in FlashAttention) 2023-03-15 17:16:00 -04:00
Vik Paruchuri
3165398074 Remove unused kwargs in flashattention 2023-03-15 10:36:19 -07:00
Tri Dao
e45a46a5b7 [Rotary] Implement GPT-J style (interleaved) rotary 2023-03-14 14:35:53 -07:00
Tri Dao
f28d61cb2a Update README on requirements (nvcc and Pytorch) 2023-03-13 12:48:07 -07:00
Tri Dao
57ee618170 Merge pull request #94 from calebthomas259/main (Add a simple tutorial to README.md) 2023-02-14 19:03:08 -08:00
Tri Dao
2dc2a19589 Update roadmap 2023-02-09 12:21:30 -08:00
Tri Dao
06da275bcb Merge pull request #110 from eltociear/patch-1 (fix typo in default.yaml) 2023-01-27 12:18:16 -08:00
Tri Dao
6b4a48218e [FA] Remove unused variable rng_engine_inputs 2023-01-25 15:32:40 -08:00
Tri Dao
78b7a1dc18 [OPT] Load fp16 weights on CPU before moving to GPU 2023-01-22 17:01:32 -08:00
Ikko Eltociear Ashimine
419ea45b64 fix typo in default.yaml (additionaly -> additionally) 2023-01-21 00:47:12 +09:00
Tri Dao
33e0860c9c Bump to v0.2.8 2023-01-19 13:17:19 -08:00
Tri Dao
eb33e587e9 [LayerNorm] Rename x1 -> residual 2023-01-19 13:07:27 -08:00
Tri Dao
f68d41ec77 [Gen] Add OPT to generation test 2023-01-17 19:59:06 -08:00
Tri Dao
88173a1aaf [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP 2023-01-17 18:12:27 -08:00
Tri Dao
780e8eeabb [ViT] Support timm checkpoint, add tests 2023-01-16 01:20:34 -08:00
Tri Dao
2ec7d3f72c Merge pull request #105 from jamaliki/patch-1 (Change default dropout value in documentation) 2023-01-15 23:01:20 -08:00
Tri Dao
ef085cfcda [ViT] Fix extra norm_0, use new LN order in Block 2023-01-15 22:58:56 -08:00