Tri Dao
65c234ed90
Don't over-allocate dq_accum in case of varlen
2023-09-24 00:36:07 -07:00
Tri Dao
1879e089c7
Reduce number of templates for headdim > 128
2023-09-23 22:24:30 -07:00
Tri Dao
2d8ea9a530
Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza)
2023-09-20 23:38:22 -07:00
Tri Dao
43617deab9
Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation
2023-09-18 12:21:36 -07:00
Tri Dao
c984208ddb
Set block size to 64 x 64 for kvcache to avoid nvcc segfaults
2023-09-17 16:14:58 -07:00
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
Tri Dao
56b7fc6ee0
Simplify the implementation of KVcache attn by appending KV first
2023-09-13 15:55:48 -07:00
Tri Dao
bb9beb3645
Remove some unused headers
2023-09-12 12:37:10 -07:00
Tri Dao
ee77b931b9
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
2023-09-10 22:56:33 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
Tri Dao
6a89b2f121
Remove constexpr in launch template to fix CI compilation
2023-09-03 22:59:41 -07:00
Tri Dao
1dc1b6c8f2
Bump to v2.1.2
2023-09-03 22:23:05 -07:00
Tri Dao
5953c4f58c
Remove unused sdPsum in dot_do_o function
2023-09-03 20:44:07 -07:00
Tri Dao
26d7d92f3d
Fix splitKV combine function when local LSEs are all -inf
2023-09-03 11:39:09 -07:00
Sophia Wisdom
37e32febba
Remove commented out code in bwd ( #512 )
...
* Remove lots of comments
* Remove unused traits
2023-09-01 16:43:58 -07:00
Sophia Wisdom
dd8a754915
Remove old code in utils.h ( #511 )
2023-09-01 15:32:09 -07:00
Tri Dao
31920dda5f
Fix typo with lse_max == -INFINITY
2023-08-29 21:48:59 -07:00
Tri Dao
b1fbbd8337
Implement splitKV attention
2023-08-29 00:58:29 -07:00
Tri Dao
7a983df742
Use generate_kernels.py script from Driss Guessous
2023-08-28 13:34:12 -07:00
Tri Dao
9e5e8bc91e
Change causal mask to be aligned to bottom-right instead of top-left
2023-08-24 23:41:07 -07:00
BoxiangW
e07aa036db
Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. ( #436 )
2023-08-24 16:42:34 -07:00
Tri Dao
bcfa7c9751
[FusedDense] Run black on fused_dense.py
2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac
Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal
2023-08-16 15:12:36 -07:00
Tri Dao
dbd7923782
Prepare for Cutlass 3.2
2023-08-13 15:24:32 -07:00
Tri Dao
3524e13c11
Update to Cutlass 3.1
2023-08-13 13:53:17 -07:00
Tri Dao
1c41d2b0e5
Fix race condition in bwd (overwriting sK)
2023-08-01 09:00:10 -07:00
Tri Dao
a4f148b6ab
Fix masking of bwd when seqlen is not divisible by 128
2023-07-31 17:46:34 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs ( #386 )
...
* Add RNG state to kernel launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Save seed and offset for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Single thread write to global mem
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1colblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1rowblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change forward c++ APIs to save RNG state for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change backward c++ APIs to set RNG state for bprop launcher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Python side API changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fix; only save seeds instead of full offset
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
9ee0ff1d9b
Fix using dO stride for O, which can cause memory error in bwd
2023-07-20 17:39:57 -07:00
danthe3rd
538d570c96
Fix compile error on MSVC
...
See also: https://stackoverflow.com/questions/55136414/constexpr-variable-captured-inside-lambda-loses-its-constexpr-ness
2023-07-19 08:04:57 +00:00
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Kirthi Shankar Sivamani
45567a25a2
only 1 thread writes to global mem in fprop
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-15 06:09:41 +00:00
Kirthi Shankar Sivamani
31018c5fa0
Support CUDA graph capture
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
1b18f1b7a1
Support H100
2023-03-15 14:59:02 -07:00
Tri Dao
a1f49a2b92
[Compilation] Change BOOL_SWITCH to fix Windows compilation
...
Follow xFormers's DISTPATCH_BOOL. Haven't tested it on Windows.
2023-01-06 14:40:58 -08:00
Tri Dao
8a2ece89f7
Simplify BOOL_SWITCH macro to fix compiling error on gcc 7
2022-12-06 14:38:32 -08:00
Tri Dao
9bc63d1e2d
Fix typo in comments
2022-11-25 16:35:08 -08:00
Tri Dao
d95ee1a95d
Speed up compilation by splitting into separate .cu files
2022-11-25 16:30:18 -08:00
Tri Dao
6998e0ecdb
Fix out-of-bound memory read
2022-11-09 09:34:14 -08:00
Tri Dao
557781933d
Parallelize CUDA bwd along seqlen_k instead of seqlen_q
...
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
Tri Dao
c422fee377
Get rid of o_rows_are_valid since we don't have headdim=16 anymore
2022-10-24 17:29:36 -07:00
Tri Dao
46fd2a20b2
Support all head dims that are multiples of 8, up to 128
2022-10-24 16:04:21 -07:00
Tri Dao
9e92a1f2d2
Attempt to use atomicCAS to replace atomicAdd(bfloat16)
2022-10-23 16:22:43 -07:00
Tri Dao
a5a8806d1a
Split bwd on the seqlen_q dimension
2022-10-23 11:35:15 -07:00
Tri Dao
871db47941
Don't need to run configure for the forward pass
2022-10-21 18:22:27 -07:00
Tri Dao
7fc39832e2
Use block_size=128 for headdim=128 on SM80
...
Previously we were using block_size=256.
2022-10-21 13:19:54 -07:00
Tri Dao
a44f48df5a
Split fwd on the seqlen_q dimension
2022-10-21 12:04:27 -07:00
Tri Dao
1aa6d7d9b6
Rework dropout to decouple forward and backward
...
They don't have to have the same block size, number of threads, etc.
2022-10-21 12:04:27 -07:00
Tri Dao
5badfb7848
Implement attention kernel that splits the batch into two
2022-10-13 20:49:02 -07:00
Eric Engelhart
2211db5fab
Fixed switch statement, thanks @yocabon
2022-10-04 21:31:39 -04:00
Eric Engelhart
9d7fd5b6e7
Replace BOOL_SWITCH with FP16_SWITCH to work around MSVC bug with constexpr variables and templates
2022-10-04 21:31:39 -04:00
Tri Dao
8166063a55
Use block_size=128 for d=128 on SM86 to avoid exceeding smem limit
2022-09-12 14:21:29 -07:00
Tri Dao
bc2c210254
Don't nest BOOL_SWITCH to work around gcc 7 bug
2022-07-11 10:28:46 -07:00
Tri Dao
de19de7ab1
Implement for bf16
2022-07-09 23:31:56 -07:00
Tri Dao
6a77a6da10
Refactor gemm_cl to template on either __half or __nv_bfloat16
2022-07-09 23:18:26 -07:00
Tri Dao
e518a4b327
Refactor to template on __half, implement bf16 util functions
2022-07-09 23:18:26 -07:00
Tri Dao
2dc1b205f6
Fix Illegal Memory Access bug in fwd when d=16
2022-07-09 23:17:14 -07:00
Tri Dao
5b838a8bef
Apply dropout scaling to dQ and dK instead of to V (in bwd)
...
Theoretically this might have lower numerical error since the scaling is in
fp32 instead of fp16 (not sure, I haven't thought too carefully about it).
However, in practice, the numerical errors seem about the same.
2022-07-03 17:53:37 -07:00
Tri Dao
a5559a0e75
Do P * dP (pointwise) in the bwd in fp32 instead of fp16
2022-07-03 17:52:05 -07:00
Tri Dao
6c3a8c65af
Implement cross attention
2022-07-03 17:48:12 -07:00
Tri Dao
f66603cb6f
Support batch size > 64K by swapping grid.x and grid.y
2022-06-29 23:16:24 -07:00
Tri Dao
ea38d3d261
Fix race condition in backward pass (smem_dq)
2022-06-25 18:02:30 -07:00
Tri Dao
eeca63a72a
Bug fix: wrong smem_o write pointer for d=16
2022-06-25 15:18:33 -07:00
Tri Dao
5d07483bbc
Refactor Gmem code to store q, k, v pointers separately
2022-06-12 16:37:32 -07:00
Tri Dao
d3e6440958
Implement bwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
0d854692c6
Implement fwd for head dim 128
2022-06-11 17:52:36 -07:00
Tri Dao
321c57d07d
Set block size of SM75 fwd to 256 if there's no dropout
...
This speeds up the fwd by 1.5x.
2022-06-04 16:51:28 -07:00
Tri Dao
d380e87fb6
Don't use Smem_dp_sum in backward pass
...
To reduce smem usage for SM75
2022-06-04 16:01:36 -07:00
Tri Dao
b17c6fe235
Reduce smem usage for Q and dO in the backward pass
...
From 4KB per buffer to 2KB per buffer. This saves us 8KB of smem (each Q and dO
have 2 buffers)
2022-06-03 16:59:11 -07:00
Tri Dao
2712aa4c8d
Support Turing mma instructions
2022-06-03 16:58:44 -07:00
Tri Dao
050873327e
Remove softmax fp16 max
2022-06-02 14:09:46 -07:00
Tri Dao
14dc326e59
Use Cutlass gemm as WarpMma
2022-06-02 10:33:32 -07:00
Tri Dao
e78e7c9553
Remove old backward
2022-06-02 10:13:44 -07:00
Tri Dao
c41479d66d
Support SM86 GPUs
2022-06-01 18:49:47 -07:00
Tri Dao
9dbc491aa5
Rename, add benchmarking script
2022-05-26 13:57:38 -07:00