Commit Graph

201 Commits

Author SHA1 Message Date
Tri Dao
65f723bb9a Split bwd into more .cu files to speed up compilation 2024-07-23 01:32:09 -07:00
Tri Dao
5ca83a9c71 Clean up softcapping bwd a bit 2024-07-23 00:13:54 -07:00
Tri Dao
751c762c9c Don't specialize for hdim 224 to speed up compilation 2024-07-23 00:13:54 -07:00
Driss Guessous
1c275eb070
Fix ima for split-kv kernel (#1085) 2024-07-22 22:19:46 -07:00
rocking
d8f104e97a
Support AMD ROCm on FlashAttention 2 (#1010)
* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use same python as build flash-attn to generate ck kernel

* Fix bug of group mode fwd about returning softmax lse

* larger the test tollerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine get value from tuple

* Use default parameter for stream_config

* unblock all platform

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* update to lastest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
2024-07-22 21:34:37 -07:00
Phil Wang
5f1ae4a34b
backwards for softcapping (#1033)
* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
2024-07-21 23:25:46 -07:00
Jorge António
4df62e1440
catch typo (#1058) 2024-07-21 23:24:15 -07:00
Tri Dao
74b0761ff7 [FA3] BF16 forward 2024-07-14 23:39:46 -07:00
Tri Dao
b4a9dd6c9c Temporarily switch to cutlass fork for more shapes 2024-07-11 09:29:21 -07:00
Tri Dao
40e534a7f6 Implement cache_leftpad 2024-07-11 08:17:15 -07:00
Tri Dao
dca6d89da4 Don't support softcap and dropout at the same time
These tests are failing so I'm just disabling this case for now
2024-07-10 11:23:12 -07:00
Tri Dao
908511b2b6 Split into more .cu files to speed up compilation 2024-07-10 00:24:04 -07:00
Tri Dao
1d536d7de5 Minor cleanup of softcapping 2024-07-09 22:57:03 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
66RING
9486635c92
Fix typos of comments about shape. (#837) 2024-06-30 22:40:59 -07:00
Nicolas Patry
5bf201966a
Fixing argument checking when using seqlenq_ngroups_swapped. (#976)
When user send `out` as a parameter of the function
`seqlenq_ngroups_swapped` with parameters that trigger,
the CHECK_SHAPE is incorrect (since q shape is modified.)
2024-06-30 22:39:22 -07:00
Liang
ab59ec3590
remove swizzle part of sV.data() to get a completely non-swizzle sVtNoSwizzle (#984)
Co-authored-by: zl <zl@deepseek.com>
2024-06-30 22:38:44 -07:00
Grigory Sizov
f816dee63c
Support unpadded LSE layout (#970)
* Support unpadded LSE layout.

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>

* Cleanup

* Fix unpadded LSE on split-kv path

* Fix formatting and comments

* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
Tri Dao
d732be1e67 Update to Cutlass 3.5 2024-05-26 12:49:33 -07:00
Tri Dao
656daef4ea Use Cute's local_tile to get gQ, gK, gV 2024-04-07 20:10:19 -07:00
Tri Dao
9eb3d099c1 Transpose out when swapping seqlen_q and num_groups 2024-04-07 20:10:19 -07:00
Driss Guessous
23e8fa5a26
Add the option for the macro and note (#893) 2024-03-27 19:12:11 -07:00
ljss
3e9414f1c3
Minor fix in compute_attn_1rowblock_splitkv (#900) 2024-03-27 19:11:45 -07:00
Driss Guessous
4a73e903da
Add in, macrosf for defining __grid_constant__ (#852) 2024-03-15 00:48:54 -07:00
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward (#831)
* Enable paged attention in varlen forward

* Format + fix padding
2024-03-15 00:48:19 -07:00
Chirag Jain
50896ec574
Make nvcc threads configurable via environment variable (#885) 2024-03-13 20:46:57 -07:00
Tri Dao
2406f28805 Enable headdim 256 backward on consumer GPUs (Ampere, Ada) 2024-02-21 15:56:19 -08:00
Tri Dao
4d6b794b3c Update Cutlass to v3.4.1 2024-02-20 16:28:21 -08:00
Tri Dao
b32efb1a4d Don't need to reduce row_sum during online softmax 2024-02-20 13:33:38 -08:00
Tri Dao
d9a5cb291c Fix dv = torch::empty_like(k) for mha_bwd_varlen as well 2024-02-10 01:03:00 -08:00
Brian Hirsh
2423cca3ad
fix backward for when query and key have different contiguity (#818) 2024-02-10 01:01:27 -08:00
Grigory Sizov
4687936413
Fix Windows build (#816) 2024-02-07 17:41:53 -08:00
Jeremy Reizenstein
0658e320f6
Preprocessor switches to control functionality (#788)
For faster and smaller builds in some simple cases,
provide switches to allow disabling
-backward
-alibi
-uneven k
-dropout
-local attention

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2024-01-29 20:44:23 -08:00
Tri Dao
54e80a3829 Implement page KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Tri Dao
36bc29edf7 Use int64_t instead of uint32_t in kernel_traits.h 2024-01-22 22:39:29 -08:00
Tri Dao
000b67f5d8 Use int64_t instead of uint32_t for index_t 2024-01-22 11:25:50 -08:00
Tri Dao
ea8a25ca38 Remove configure in bwd kernel launch 2024-01-21 15:28:33 -08:00
Grigory Sizov
af01244ddd
Add split-kv and M<->H swap to varlen forward decoding attention (#754)
* Add split-k, M<->H to varseq path

* skip M<->H when dropout>0, fix LSE
2024-01-21 15:28:36 -08:00
Tri Dao
8f4d82cf5e Update cutlass to v3.4.0 2024-01-20 22:30:06 -08:00
Tri Dao
395e5a0dba Move rotary device functions to a separate file 2024-01-20 18:01:18 -08:00
Tri Dao
3e2c827d9a Remove unused kernel_traits file 2024-01-20 17:41:44 -08:00
Tri Dao
66a127aef8 Refactor masking in fwd pass into 1 object 2024-01-20 17:39:53 -08:00
Tri Dao
ed4959b2eb Change inline to __forceinline__, use __grid_constant__ param 2024-01-20 17:38:47 -08:00
Tri Dao
6f706eff96 Make Softmax an object 2024-01-19 16:09:31 -08:00
Tri Dao
4ea866ca19 Make Alibi an object 2024-01-15 00:07:11 -08:00
Tri Dao
5aca153d6d Move bwd preprocess kernels to a separate file 2024-01-14 16:57:03 -08:00
Tri Dao
df1418f9db Move softmax_rescale_o to softmax.h 2024-01-14 15:06:06 -08:00
Tri Dao
6777336a1c Move masking to a separate file (mask.h) 2024-01-14 12:43:47 -08:00
Tri Dao
9448264ddd Remove seqq_parallel backward kernel that's not used 2024-01-14 12:25:49 -08:00
Tri Dao
1274ec3e7e Move dropout to a separate file (dropout.h) 2024-01-14 12:19:17 -08:00