Tri Dao
65f723bb9a
Split bwd into more .cu files to speed up compilation
2024-07-23 01:32:09 -07:00
Tri Dao
5ca83a9c71
Clean up softcapping bwd a bit
2024-07-23 00:13:54 -07:00
Tri Dao
751c762c9c
Don't specialize for hdim 224 to speed up compilation
2024-07-23 00:13:54 -07:00
Driss Guessous
1c275eb070
Fix ima for split-kv kernel ( #1085 )
2024-07-22 22:19:46 -07:00
Phil Wang
5f1ae4a34b
backwards for softcapping ( #1033 )
...
* check in the two ways of approaching backwards for softcapping, both functional
* prepare the softcap switch for backwards
* temporary
* cleanup to the way Tri prefers
* calculate dtanh when copying from scores -> dtanh Tensor
* no ternary operators allowed for constexpr, so just use some hack found online
* fix maybe_dtanh, restore some files
* restore another file
* move calculate_dtanh to utils and colocate with apply_softcap
* cleanup
* maybe last cleanup
* save for another pr
* remove a stray line
* fix spacing
* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
2024-07-21 23:25:46 -07:00
Jorge António
4df62e1440
catch typo ( #1058 )
2024-07-21 23:24:15 -07:00
Tri Dao
40e534a7f6
Implement cache_leftpad
2024-07-11 08:17:15 -07:00
Tri Dao
dca6d89da4
Don't support softcap and dropout at the same time
...
These tests are failing so I'm just disabling this case for now
2024-07-10 11:23:12 -07:00
Tri Dao
908511b2b6
Split into more .cu files to speed up compilation
2024-07-10 00:24:04 -07:00
Tri Dao
1d536d7de5
Minor cleanup of softcapping
2024-07-09 22:57:03 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. ( #1025 )
...
* Softcap v2 (fwd only).
* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
66RING
9486635c92
Fix typos of comments about shape. ( #837 )
2024-06-30 22:40:59 -07:00
Nicolas Patry
5bf201966a
Fixing argument checking when using seqlenq_ngroups_swapped. ( #976 )
...
When user send `out` as a parameter of the function
`seqlenq_ngroups_swapped` with parameters that trigger,
the CHECK_SHAPE is incorrect (since q shape is modified.)
2024-06-30 22:39:22 -07:00
Liang
ab59ec3590
remove swizzle part of sV.data() to get a completely non-swizzle sVtNoSwizzle ( #984 )
...
Co-authored-by: zl <zl@deepseek.com>
2024-06-30 22:38:44 -07:00
Grigory Sizov
f816dee63c
Support unpadded LSE layout ( #970 )
...
* Support unpadded LSE layout.
Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
* Cleanup
* Fix unpadded LSE on split-kv path
* Fix formatting and comments
* Fix inline vs forceinline
---------
Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
Tri Dao
d732be1e67
Update to Cutlass 3.5
2024-05-26 12:49:33 -07:00
Tri Dao
656daef4ea
Use Cute's local_tile to get gQ, gK, gV
2024-04-07 20:10:19 -07:00
Tri Dao
9eb3d099c1
Transpose out when swapping seqlen_q and num_groups
2024-04-07 20:10:19 -07:00
Driss Guessous
23e8fa5a26
Add the option for the macro and note ( #893 )
2024-03-27 19:12:11 -07:00
ljss
3e9414f1c3
Minor fix in compute_attn_1rowblock_splitkv ( #900 )
2024-03-27 19:11:45 -07:00
Driss Guessous
4a73e903da
Add in, macrosf for defining __grid_constant__ ( #852 )
2024-03-15 00:48:54 -07:00
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward ( #831 )
...
* Enable paged attention in varlen forward
* Format + fix padding
2024-03-15 00:48:19 -07:00
Tri Dao
2406f28805
Enable headdim 256 backward on consumer GPUs (Ampere, Ada)
2024-02-21 15:56:19 -08:00
Tri Dao
b32efb1a4d
Don't need to reduce row_sum during online softmax
2024-02-20 13:33:38 -08:00
Tri Dao
d9a5cb291c
Fix dv = torch::empty_like(k) for mha_bwd_varlen as well
2024-02-10 01:03:00 -08:00
Brian Hirsh
2423cca3ad
fix backward for when query and key have different contiguity ( #818 )
2024-02-10 01:01:27 -08:00
Grigory Sizov
4687936413
Fix Windows build ( #816 )
2024-02-07 17:41:53 -08:00
Jeremy Reizenstein
0658e320f6
Preprocessor switches to control functionality ( #788 )
...
For faster and smaller builds in some simple cases,
provide switches to allow disabling
-backward
-alibi
-uneven k
-dropout
-local attention
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2024-01-29 20:44:23 -08:00
Tri Dao
54e80a3829
Implement page KV cache
...
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Tri Dao
36bc29edf7
Use int64_t instead of uint32_t in kernel_traits.h
2024-01-22 22:39:29 -08:00
Tri Dao
000b67f5d8
Use int64_t instead of uint32_t for index_t
2024-01-22 11:25:50 -08:00
Tri Dao
ea8a25ca38
Remove configure in bwd kernel launch
2024-01-21 15:28:33 -08:00
Grigory Sizov
af01244ddd
Add split-kv and M<->H swap to varlen forward decoding attention ( #754 )
...
* Add split-k, M<->H to varseq path
* skip M<->H when dropout>0, fix LSE
2024-01-21 15:28:36 -08:00
Tri Dao
8f4d82cf5e
Update cutlass to v3.4.0
2024-01-20 22:30:06 -08:00
Tri Dao
395e5a0dba
Move rotary device functions to a separate file
2024-01-20 18:01:18 -08:00
Tri Dao
3e2c827d9a
Remove unused kernel_traits file
2024-01-20 17:41:44 -08:00
Tri Dao
66a127aef8
Refactor masking in fwd pass into 1 object
2024-01-20 17:39:53 -08:00
Tri Dao
ed4959b2eb
Change inline to __forceinline__, use __grid_constant__ param
2024-01-20 17:38:47 -08:00
Tri Dao
6f706eff96
Make Softmax an object
2024-01-19 16:09:31 -08:00
Tri Dao
4ea866ca19
Make Alibi an object
2024-01-15 00:07:11 -08:00
Tri Dao
5aca153d6d
Move bwd preprocess kernels to a separate file
2024-01-14 16:57:03 -08:00
Tri Dao
df1418f9db
Move softmax_rescale_o to softmax.h
2024-01-14 15:06:06 -08:00
Tri Dao
6777336a1c
Move masking to a separate file (mask.h)
2024-01-14 12:43:47 -08:00
Tri Dao
9448264ddd
Remove seqq_parallel backward kernel that's not used
2024-01-14 12:25:49 -08:00
Tri Dao
1274ec3e7e
Move dropout to a separate file (dropout.h)
2024-01-14 12:19:17 -08:00
Tri Dao
10dad61277
apply_dropout now takes tensor of rowcol layout
2024-01-14 01:03:23 -08:00
Tri Dao
d9cbcfb41c
Remove dead code in philox.cuh
2024-01-13 02:02:03 -08:00
Tri Dao
a7b66ae25a
Simplify writing softmax to gmem
2024-01-13 00:25:04 -08:00
Tri Dao
8d1b169ed1
Simplify SmemLayoutVtransposed in kernel_traits.h
2024-01-12 11:53:29 -08:00
Tri Dao
0842ec0da4
Don't dispatch to local if window size >= seqlen_k
2023-12-23 20:59:26 -08:00