Tri Dao
10dad61277
apply_dropout now takes tensor of rowcol layout
2024-01-14 01:03:23 -08:00
Tri Dao
d9cbcfb41c
Remove dead code in philox.cuh
2024-01-13 02:02:03 -08:00
Tri Dao
a7b66ae25a
Simplify writing softmax to gmem
2024-01-13 00:25:04 -08:00
Tri Dao
8d1b169ed1
Simplify SmemLayoutVtransposed in kernel_traits.h
2024-01-12 11:53:29 -08:00
Tri Dao
abbc131173
[LayerNorm] Switch from CUDA to Triton implementation
2024-01-05 00:31:17 -08:00
Tri Dao
0842ec0da4
Don't dispatch to local if window size >= seqlen_k
2023-12-23 20:59:26 -08:00
Tri Dao
732654583c
Implement deterministic backward (thanks to Meituan)
2023-12-23 17:57:36 -08:00
Tri Dao
8448c02889
Update cutlass to v3.3.0
2023-12-21 23:25:50 -08:00
Tri Dao
5ab9b3667b
Clean up alibi, implement non-causal alibi
2023-12-21 22:27:40 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
...
* hard-code alibi in fwd
* use params.h as hun_heads
* hard-code alibi in bwd
* add alibi on/off option
* compute alibi_start, ratio outside of kernels
* fix minor merge conflict
* add test_alibi.py
* change apply_alibi() location before masking
* add alibi in splitkv kernel
* fix backward func # of returns
* add out-of-bound check in apply_alibi()
* update test_alibi.py
* update test_alibi.py for kvcache
* simplify alibi parameter interface
* fix performance issue
by computing alibi outside of branch
* update test_flash_attn_varlen_func() for left padding
* implement alibi_slopes (b, nh) loading
* optimize apply_alibi() a bit
* update test cases for alibi_slopes loading
* reflect stylistic comments
* disable "seqlenq_ngroups_swapped" when using alibi
---------
Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k ( #647 )
...
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
b4bf9cc1f3
Fix performance regression with causal
2023-11-26 19:07:25 -08:00
Tri Dao
db2f80692c
Write zero to out / grad if seqlen_q or seqlen_k is zero
2023-11-19 22:20:01 -08:00
Tri Dao
43bb6d8aaa
Update cutlass to 3.2.2
2023-11-19 21:43:48 -08:00
Driss Guessous
dc4b9ad6c4
add checks ( #640 )
2023-11-19 20:43:27 -08:00
Tri Dao
5a83425442
Change constexpr int to constexpr static int
2023-10-08 16:26:33 -07:00
Tri Dao
e279bf8ed9
[Gen] Accept cache_batch_idx to index into the KV cache
2023-10-03 16:27:26 -07:00
Tri Dao
083e8f525f
Implement local attention
...
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Tri Dao
812cb1c990
Switch cutlass to newer commit to avoid compilation warning
2023-09-24 00:42:50 -07:00
Tri Dao
65c234ed90
Don't over-allocate dq_accum in case of varlen
2023-09-24 00:36:07 -07:00
Tri Dao
1879e089c7
Reduce number of templates for headdim > 128
2023-09-23 22:24:30 -07:00
Tri Dao
2d8ea9a530
Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza)
2023-09-20 23:38:22 -07:00
Tri Dao
dfe29f5e2b
[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead
2023-09-18 15:29:06 -07:00
Tri Dao
3250ff3d82
Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H)
2023-09-18 14:52:16 -07:00
Tri Dao
43617deab9
Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation
2023-09-18 12:21:36 -07:00
Tri Dao
c984208ddb
Set block size to 64 x 64 for kvcache to avoid nvcc segfaults
2023-09-17 16:14:58 -07:00
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
Tri Dao
5400fdc4ac
[CE] Implement CrossEntropyLoss in Triton
2023-09-15 20:05:28 -07:00
Tri Dao
56b7fc6ee0
Simplify the implementation of KVcache attn by appending KV first
2023-09-13 15:55:48 -07:00
Tri Dao
bb9beb3645
Remove some unused headers
2023-09-12 12:37:10 -07:00
Tri Dao
ee77b931b9
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
2023-09-10 22:56:33 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
Tri Dao
6a89b2f121
Remove constexpr in launch template to fix CI compilation
2023-09-03 22:59:41 -07:00
Tri Dao
97ba7a62e9
Try switching back to Cutlass 3.2.0
2023-09-03 22:45:35 -07:00
Tri Dao
1dc1b6c8f2
Bump to v2.1.2
2023-09-03 22:23:05 -07:00
Tri Dao
5953c4f58c
Remove unused sdPsum in dot_do_o function
2023-09-03 20:44:07 -07:00
Tri Dao
26d7d92f3d
Fix splitKV combine function when local LSEs are all -inf
2023-09-03 11:39:09 -07:00
Sophia Wisdom
37e32febba
Remove commented out code in bwd ( #512 )
...
* Remove lots of comments
* Remove unused traits
2023-09-01 16:43:58 -07:00
Sophia Wisdom
dd8a754915
Remove old code in utils.h ( #511 )
2023-09-01 15:32:09 -07:00
Aman Gupta Karmani
866a9d33f9
bump cutlass submodule ( #504 )
2023-08-30 10:32:04 -07:00
Tri Dao
31920dda5f
Fix typo with lse_max == -INFINITY
2023-08-29 21:48:59 -07:00
Tri Dao
b1fbbd8337
Implement splitKV attention
2023-08-29 00:58:29 -07:00
Tri Dao
7a983df742
Use generate_kernels.py script from Driss Guessous
2023-08-28 13:34:12 -07:00
dan_the_3rd
c3f2a632aa
[ft_attention] Fix for seqlen=8136 ( #488 )
...
When seqlen=8136, `smem_sz = 48840`, and apparently starting the kernel returns an `invalid argument` CUDA error.
`48840 < 48 * 1024` but apparently it's still above the limit somehow..?
Tested on A100
2023-08-28 10:00:22 -07:00
Tri Dao
757058d4d3
Update Cutlass to v3.2.0
2023-08-27 23:47:28 -07:00
Tri Dao
9e5e8bc91e
Change causal mask to be aligned to bottom-right instead of top-left
2023-08-24 23:41:07 -07:00
BoxiangW
e07aa036db
Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. ( #436 )
2023-08-24 16:42:34 -07:00
Tri Dao
bcfa7c9751
[FusedDense] Run black on fused_dense.py
2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac
Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal
2023-08-16 15:12:36 -07:00
Tri Dao
dbd7923782
Prepare for Cutlass 3.2
2023-08-13 15:24:32 -07:00
Tri Dao
3524e13c11
Update to Cutlass 3.1
2023-08-13 13:53:17 -07:00
Tri Dao
1c41d2b0e5
Fix race condition in bwd (overwriting sK)
2023-08-01 09:00:10 -07:00
Tri Dao
a4f148b6ab
Fix masking of bwd when seqlen is not divisible by 128
2023-07-31 17:46:34 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs ( #386 )
...
* Add RNG state to kernel launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Save seed and offset for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Single thread write to global mem
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1colblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1rowblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change forward c++ APIs to save RNG state for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change backward c++ APIs to set RNG state for bprop launcher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Python side API changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fix; only save seeds instead of full offset
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Joel Lamy-Poirier
767b71ccf0
Fix random state for dropout_layer_norm ( #315 )
2023-07-23 15:05:13 -07:00
Tri Dao
a157cc8c9b
[FT] Implement MQA/GQA
2023-07-22 23:47:01 -07:00
Tri Dao
9ee0ff1d9b
Fix using dO stride for O, which can cause memory error in bwd
2023-07-20 17:39:57 -07:00
Ikko Eltociear Ashimine
dfc60f6b7d
[LayerNorm] Fix typo in ln_api.cpp
...
unintialized -> uninitialized
2023-07-20 01:16:16 +09:00
danthe3rd
538d570c96
Fix compile error on MSVC
...
See also: https://stackoverflow.com/questions/55136414/constexpr-variable-captured-inside-lambda-loses-its-constexpr-ness
2023-07-19 08:04:57 +00:00
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Tri Dao
2800efc71f
[FT] rotary_cos/sin should have batch_size dimension
2023-07-06 15:33:33 -07:00
Tri Dao
3a9bfd076f
[FT] rotary_cos/sin should have shape (dim) instead of (seqlen, dim)
2023-07-03 09:41:04 -07:00
Tri Dao
62e9814466
[Rotary] Make sure frequency calculation is in fp32
2023-07-02 16:39:39 -07:00
Tri Dao
27f8f890df
[FusedDense] Allocate lt_workspace on input device
2023-05-30 14:17:26 -07:00
Tri Dao
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
2023-05-30 13:38:34 -07:00
Tri Dao
ad113948a6
[Docs] Clearer error message for bwd d > 64, bump to v1.0.4
2023-04-26 09:19:48 -07:00
Tri Dao
311d6606bf
[Gen] Fix FT kernel smem size, CG when batch size changed
2023-04-20 17:03:13 -07:00
Kirthi Shankar Sivamani
45567a25a2
only 1 thread writes to global mem in fprop
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-15 06:09:41 +00:00
Kirthi Shankar Sivamani
7d25a4ec4f
Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
315fd31f0c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-12 22:42:24 -07:00
Kirthi Shankar Sivamani
31018c5fa0
Support CUDA graph capture
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
dec4f2e910
[FusedDense] Set workspace size to 32M for Hopper and 4M for others
2023-04-06 23:40:15 -07:00
Tri Dao
393882bc08
[LayerNorm] Implement LN with parallel residual, support dim 8k
2023-03-31 14:23:45 -07:00
Tri Dao
f5d0fbd468
[FT] Fix FT's single query attention for bf16 hdim128 rotary
2023-03-28 21:27:00 -07:00
Tri Dao
dc08ea1c33
Support H100 for other CUDA extensions
2023-03-15 16:59:27 -07:00
Tri Dao
1b18f1b7a1
Support H100
2023-03-15 14:59:02 -07:00
Tri Dao
e45a46a5b7
[Rotary] Implement GPT-J style (interleaved) rotary
2023-03-14 14:35:53 -07:00
Tri Dao
6b4a48218e
[FA] Remove unused variable rng_engine_inputs
2023-01-25 15:32:40 -08:00
Tri Dao
eb33e587e9
[LayerNorm] Rename x1 -> residual
2023-01-19 13:07:27 -08:00
Tri Dao
88173a1aaf
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
2023-01-17 18:12:27 -08:00
Tri Dao
f1e01c27ba
[Gen] Pass qkv_stride to ft_attention kernel for batched generation
2023-01-15 15:20:01 -08:00
Tri Dao
7c2191542a
[Gen] Make generation work with Tensor Parallel
2023-01-15 11:34:27 -08:00
Tri Dao
6738d9477d
[LayerNorm] Implement RMS Norm
2023-01-06 17:34:22 -08:00
Tri Dao
a1f49a2b92
[Compilation] Change BOOL_SWITCH to fix Windows compilation
...
Follow xFormers's DISTPATCH_BOOL. Haven't tested it on Windows.
2023-01-06 14:40:58 -08:00
Tri Dao
be1afaa276
[Gen, FT] Use fp32 accum for FMA
2023-01-03 22:09:22 -08:00
Tri Dao
f266fc7262
[Gen, FT] Use tlength instead of params.timestep for rotary
2023-01-03 17:46:55 -08:00
Tri Dao
a01d1213d7
[Gen] Add kernel from FasterTransformer for benchmarking
2023-01-03 17:37:43 -08:00
Tri Dao
a8cfe51551
Implement Tensor Parallel for transformer Block
2022-12-25 14:08:21 -08:00
Tri Dao
1e712ea8b0
Implement TensorParallel for MHA
2022-12-25 11:39:55 -08:00
Tri Dao
226a1b721d
Implement TensorParallel for FusedDense and FusedDenseGeluDense
2022-12-24 11:48:56 -08:00
Tri Dao
dff68c2b22
Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss
2022-12-23 14:51:08 -08:00
Tri Dao
e68ebbe89a
Simplify FusedDense
2022-12-22 21:25:31 -08:00
Tri Dao
5db330519a
[LayerNorm] Support taking subset of input or subset of output
2022-12-12 22:16:14 -08:00
Tri Dao
ae137ed17a
[LayerNorm] Fuse LayerScale
2022-12-10 23:28:23 -08:00
Tri Dao
8c6609ae1a
[LayerNorm] Support all dimensions up to 6k (if divisible by 8)
2022-12-09 02:06:22 -08:00
Tri Dao
8a2ece89f7
Simplify BOOL_SWITCH macro to fix compiling error on gcc 7
2022-12-06 14:38:32 -08:00
Tri Dao
0bf5e50038
Release training code
2022-11-28 17:34:40 -08:00
Tri Dao
9bc63d1e2d
Fix typo in comments
2022-11-25 16:35:08 -08:00
Tri Dao
d95ee1a95d
Speed up compilation by splitting into separate .cu files
2022-11-25 16:30:18 -08:00
Tri Dao
39ed597b28
[LayerNorm] Compile for both sm70 and sm80
2022-11-17 11:45:11 -08:00