Commit Graph

683 Commits

Author SHA1 Message Date
juejuezi
e371bea04f
feat: change minimal supported CUDA version to 11.7 (#1206) 2024-09-05 10:34:35 -07:00
Cameron Shinn
3cea2fb6ee
Add ArchTag to pre/postprocess bwd kernels (#1180)
* Add ArchTag to pre/postprocess bwd kernels

* Type-dependent CC check for bwd pre/postprocess

* Fix CC >= 90 for bwd postprocess

---------

Co-authored-by: Cameron Shinn <cshinn@nvidia.com>
2024-08-28 00:20:47 -07:00
jayhshah
c92ca63268
FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
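For context on the FP8 descales mentioned in this commit: Q/K/V are stored in FP8 with per-tensor scale factors, and the accumulated scores are multiplied by the corresponding descale factors to restore magnitude. A minimal PyTorch sketch of that arithmetic; the names (quantize_fp8, descale_q, ...) are illustrative, not the kernel's API:

```python
# Minimal sketch of per-tensor FP8 descaling as used conceptually in FP8 attention.
# Names (quantize_fp8, descale_q, ...) are illustrative, not the kernel's API.
import torch

def quantize_fp8(x, fp8_max=448.0):
    """Quantize to float8_e4m3 with a per-tensor scale; return the FP8 tensor and its descale."""
    scale = fp8_max / x.abs().max().clamp(min=1e-12)
    x_fp8 = (x * scale).to(torch.float8_e4m3fn)
    return x_fp8, 1.0 / scale  # descale restores the original magnitude

q = torch.randn(128, 64)
k = torch.randn(128, 64)
q_fp8, descale_q = quantize_fp8(q)
k_fp8, descale_k = quantize_fp8(k)

# Scores are accumulated in higher precision, then scaled by descale_q * descale_k.
scores = (q_fp8.float() @ k_fp8.float().T) * descale_q * descale_k
print((scores - q @ k.T).abs().max())  # small quantization error
```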
Tri Dao
d79f9b41a8 [CrossEntropy] Use online softmax to simplify implementation 2024-08-24 17:40:39 -07:00
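The online-softmax trick referenced in this commit keeps a running max and a rescaled running sum of exponentials in a single pass, so the cross-entropy loss can be formed from the logsumexp without materializing the full softmax. A minimal PyTorch sketch, illustrative rather than the repo's CUDA kernel:

```python
# Minimal sketch of online softmax / streaming logsumexp (not the repo's CUDA kernel).
import torch

def online_logsumexp(logits, block=4):
    """Single pass over blocks of logits, keeping a running max and a rescaled running sum."""
    m = torch.tensor(float("-inf"))  # running max
    s = torch.tensor(0.0)            # running sum of exp(x - m)
    for chunk in logits.split(block):
        m_new = torch.maximum(m, chunk.max())
        s = s * torch.exp(m - m_new) + torch.exp(chunk - m_new).sum()
        m = m_new
    return m + torch.log(s)

logits = torch.randn(16)
target = 3
# Cross-entropy from the streaming logsumexp: lse - logits[target]
loss = online_logsumexp(logits) - logits[target]
ref = torch.nn.functional.cross_entropy(logits[None], torch.tensor([target]))
print(torch.allclose(loss, ref))
```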
Jay Shah
32792d37ec add missing if condition for key_padding_mask in test_util.py 2024-08-19 11:17:17 -07:00
Ying Zhang
28e7f4ddbd
Merge pull request #1155 from ipiszy/fix
Fix out-of-bound writes for var-seq-len zero-length KVs
2024-08-17 13:34:06 -07:00
Ying Zhang
53537da422 add a unittest 2024-08-17 13:23:50 -07:00
Ying Zhang
a3a257c71d Fix out-of-bound writes for var-seq-len zero-length KVs 2024-08-16 01:17:40 -07:00
Tri Dao
bcd918f275 [LayerNorm] Add option to write result to out and residual_out 2024-08-15 14:43:47 -07:00
Tri Dao
bd82d6c6eb Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
2024-08-15 12:02:39 -07:00
Tri Dao
800401847e [LayerNorm] Don't store x + residual if we don't need gradients 2024-08-15 11:08:46 -07:00
Garrett Byrd
16025d8cc9
Clearer install instructions for CUDA and ROCm backends (#1147)
* Update README.md

* Update README.md

* Update README.md (Added missing line in AMD ROCm Support)
2024-08-13 22:21:56 -07:00
Ying Zhang
3669b25206
bwd benchmark + small fixes (#1129) 2024-08-05 21:27:52 -07:00
Tri Dao
5d5bfbb619 Remove contiguous checks 2024-08-05 14:47:07 -07:00
SueJane
3f1b4d38e7
Fix: check the type of max_seqlen_k instead of checking max_seqlen twice (#1127) 2024-08-05 08:59:23 -07:00
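This fix addresses a copy-paste style bug in which one argument was validated twice while its sibling went unchecked. A generic illustration of the pattern, with hypothetical names rather than the repo's actual code:

```python
# Hypothetical illustration of the bug class fixed above: max_seqlen_q validated twice
# while max_seqlen_k was never checked. Names are illustrative only.
def check_args(max_seqlen_q, max_seqlen_k):
    # Buggy version:
    #   assert isinstance(max_seqlen_q, int)
    #   assert isinstance(max_seqlen_q, int)  # copy-paste error: k is never validated
    # Fixed version:
    assert isinstance(max_seqlen_q, int), "max_seqlen_q must be an int"
    assert isinstance(max_seqlen_k, int), "max_seqlen_k must be an int"
```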
Tri Dao
3f6ff1c1c5 Remove struct : cute::aligned_struct to avoid error with gcc 12 2024-08-02 00:59:35 -07:00
Tri Dao
c33de664a1 Fix import in test 2024-08-01 02:14:25 -07:00
Tri Dao
bafe253042 [FA3] Bwd 2024-08-01 01:57:06 -07:00
Ying Zhang
abffb0f98c
Merge pull request #1115 from ipiszy/bench
Add cudnn benchmark for var-len
2024-07-31 22:42:06 -07:00
Ying Zhang
c7f20a2d31 add cudnn benchmark for var-len 2024-07-31 22:33:29 -07:00
jayhshah
5018ac6ac5
Fp8 kernel with "in-kernel" transpose of V in producer (#1100)
* base version

* restructure pipelines, add special fp8 epilogue

* add variants

* add fp8 causal and modify dynamic tile scheduler

* better causal schedule

* maintain two schedules for non causal and causal

* removing macros

* fix regression

* clean up unneeded methods and variants

* fix mistake with NumProducerThreads

* use seqlen traits

* add fp8 .cu files and benchmark script

* fix merge issue

* fix merge issue

* fix merge issue

* remove duplicate code

* fix regression with varseqlen

* move varseqlen init in constexpr

* fix test script

* more constexpr on varseqlen and add max offset

* add back test cases
2024-07-30 14:14:14 -07:00
Tri Dao
c4b9015d74 Add benchmark_gemm.py 2024-07-27 11:13:18 -07:00
Tri Dao
418d677192 Bump to v2.6.3 2024-07-25 01:31:28 -07:00
Tri Dao
65205d350e [CI] Compile for pytorch 2.4.0 2024-07-25 01:30:34 -07:00
Tri Dao
3aae9c18c1 Revert "Changes For FP8 (#1075)"
This reverts commit 1899c970c8.
2024-07-25 01:28:44 -07:00
ganeshcolfax
1899c970c8
Changes For FP8 (#1075)
* adding files for fp8 changes.

* removed contiguous check.

* enable all tests except odd-seq-lengths, where it crashes now.

* undid clang formatting.

* change to correct tile size for headdim=128.

* fixed odd-seq-len-k.

* minor formatting.

* minor reformatting.

---------

Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
2024-07-23 13:51:14 -07:00
Tri Dao
59594f2a67 Bump to v2.6.2 2024-07-23 02:30:05 -07:00
Tri Dao
299563626f Fix test with alibi and cache_leftpad 2024-07-23 02:04:15 -07:00
Tri Dao
4488acee8d [CI] Compile with torch 2.4.0.dev20240527 2024-07-23 01:33:32 -07:00
Tri Dao
65f723bb9a Split bwd into more .cu files to speed up compilation 2024-07-23 01:32:09 -07:00
Tri Dao
5ca83a9c71 Clean up softcapping bwd a bit 2024-07-23 00:13:54 -07:00
Tri Dao
751c762c9c Don't specialize for hdim 224 to speed up compilation 2024-07-23 00:13:54 -07:00
Driss Guessous
1c275eb070
Fix IMA (illegal memory access) for split-kv kernel (#1085) 2024-07-22 22:19:46 -07:00
janEbert
3c4053b75c
Make FA3 externally importable (#1053)
The library name to import is `flash_attn_interface`, which matches the test.
2024-07-22 21:34:56 -07:00
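A hedged usage sketch of the externally importable FA3 package; the exact call signature and return value of `flash_attn_func` may differ between releases:

```python
# Hedged usage sketch; the exact signature and return value of flash_attn_func
# may differ between FA3 releases (some versions return (out, lse)).
import torch
import flash_attn_interface  # FA3 package built from the hopper/ directory

q = torch.randn(2, 512, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 512, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 512, 8, 128, dtype=torch.bfloat16, device="cuda")

result = flash_attn_interface.flash_attn_func(q, k, v, causal=True)
```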
rocking
d8f104e97a
Support AMD ROCm on FlashAttention 2 (#1010)
* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line; the C++ API will handle it.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use same python as build flash-attn to generate ck kernel

* Fix bug in group-mode fwd when returning softmax lse

* Increase the test tolerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine how values are retrieved from the tuple

* Use default parameter for stream_config

* Unblock all platforms

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* Update to latest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
2024-07-22 21:34:37 -07:00
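One bullet in this commit mentions detecting the ROCm environment through PyTorch's `IS_HIP_EXTENSION`. A minimal sketch of that check; illustrative only, the repo's setup.py may structure it differently:

```python
# Minimal sketch of ROCm detection via PyTorch, as mentioned in the commit above.
# Illustrative only; the repo's setup.py may structure this check differently.
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

if IS_HIP_EXTENSION:
    print("Building against ROCm / HIP (composable_kernel backend)")
else:
    print(f"Building against CUDA {torch.version.cuda}")
```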
Ying Zhang
dfe1a59e4b
Add var-seq-len to FA3 fp16 / bf16 fwd (#1072)
* fwd var-seq-len

* fixes

* benchmark

* fixes

---------

Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
2024-07-22 21:32:41 -07:00
Cameron Shinn
cb516f855b
Remove torchlib dependency from cpp files (#1083) 2024-07-22 16:47:09 -07:00
Phil Wang
5f1ae4a34b
backwards for softcapping (#1033)
* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
2024-07-21 23:25:46 -07:00
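The softcapping backward pass in this commit needs the local derivative of the tanh cap: since softcap(s) = cap * tanh(s / cap), the gradient factor is 1 - tanh(s / cap)^2, which is what the `dtanh` tensor mentioned in the bullets holds. A small PyTorch reference sketch, not the CUDA kernel:

```python
# Reference sketch of tanh softcapping and its backward (not the CUDA kernel).
import torch

def softcap_fwd(scores, cap):
    return cap * torch.tanh(scores / cap)

def softcap_bwd(scores, cap, grad_out):
    # d/ds [cap * tanh(s / cap)] = 1 - tanh(s / cap)^2  -- the "dtanh" factor
    dtanh = 1.0 - torch.tanh(scores / cap) ** 2
    return grad_out * dtanh

s = torch.randn(4, 4, requires_grad=True)
softcap_fwd(s, cap=30.0).sum().backward()
print(torch.allclose(s.grad, softcap_bwd(s.detach(), 30.0, torch.ones_like(s))))
```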
youkaichao
ef3e358a25
remove lambda (#1056) 2024-07-21 23:24:38 -07:00
Jorge António
4df62e1440
catch typo (#1058) 2024-07-21 23:24:15 -07:00
Tri Dao
74b0761ff7 [FA3] BF16 forward 2024-07-14 23:39:46 -07:00
Tri Dao
898dd4bbf2 Pass seqused_k to _flash_attn_varlen_forward 2024-07-13 00:08:27 -07:00
Tri Dao
7ef24848cf Add FA3 image 2024-07-11 09:54:05 -07:00
Tri Dao
7f67966cc7 FA3 initial code release 2024-07-11 09:53:36 -07:00
Tri Dao
b4a9dd6c9c Temporarily switch to cutlass fork for more shapes 2024-07-11 09:29:21 -07:00
Tri Dao
7551202cb2 Bump to v2.6.1 2024-07-11 08:28:32 -07:00
Tri Dao
844912dca0 [CI] Switch from CUDA 12.2 to 12.3 2024-07-11 08:20:09 -07:00
Tri Dao
40e534a7f6 Implement cache_leftpad 2024-07-11 08:17:15 -07:00
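`cache_leftpad` lets each batch element's KV cache start at a per-sequence left offset. A hedged sketch of passing it to `flash_attn_with_kvcache`; argument names and exact semantics may differ slightly between releases:

```python
# Hedged sketch of cache_leftpad usage with flash_attn_with_kvcache.
# Argument names/semantics may differ between releases; treat this as illustrative.
import torch
from flash_attn import flash_attn_with_kvcache

batch, heads, headdim, cache_len = 2, 8, 128, 1024
q = torch.randn(batch, 1, heads, headdim, dtype=torch.bfloat16, device="cuda")
k_cache = torch.zeros(batch, cache_len, heads, headdim, dtype=torch.bfloat16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# Each sequence's cache starts at a per-batch left offset (e.g. left padding).
cache_leftpad = torch.tensor([0, 128], dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([512, 640], dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    cache_leftpad=cache_leftpad,
    causal=True,
)
```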
Tri Dao
116b05f9b0 [CI] Compile with pytorch 2.4.0.dev20240514 2024-07-11 02:53:30 -07:00
Tri Dao
da11d1b853 Bump v2.6.0 2024-07-10 21:34:58 -07:00