* base version
* restructure pipelines, add special fp8 epilogue
* add variants
* add fp8 causal and modify dynamic tile scheduler
* better causal schedule
* maintain two schedules for non causal and causal
* removing macros
* fix regression
* clean up unneeded methods and variants
* fix mistake with NumProducerThreads
* base version
* restructure pipelines, add special fp8 epilogue
* add variants
* add fp8 causal and modify dynamic tile scheduler
* better causal schedule
* maintain two schedules for non causal and causal
* removing macros
* fix regression
* clean up unneeded methods and variants
* fix mistake with NumProducerThreads
* use seqlen traits
* add fp8 .cu files and benchmark script
* fix merge issue
* fix merge issue
* fix merge issue
* remove duplicate code
* fix regression with varseqlen
* move varseqlen init in constexpr
* fix test script
* more constexpr on varseqlen and add max offset
* add back test cases
* adding files for fp8 changes.
* removed contiguous check.
* enable all tests except odd-seq-lengths, where it crashes now.
* undid clang formatting.
* change to correct tile size for headdim=128.
* fixed odd-seq-len-k.
* minor formatting.
* minor reformatting.
---------
Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
* Support ck in fmha
* Add ck submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up ck compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update ck to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py
* fix compile error
* Add mha_bwd
* Generate dropout seed and offset from user generator
* update CK
* Add mha_varlen_bwd
* Use same python as build flash-attn to generate ck kernel
* Fix bug of group mode fwd about returning softmax lse
* larger the test tollerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since we already implement mha_bwd
* Refine get value from tuple
* Use default parameter for stream_config
* unblock all platform
* Add comment
* refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size 64
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for rocm
* Detect rocm environment by pytorch's IS_HIP_EXTENSION
* update to lastest ck
* Add necessary compile flag
* Sync the api with upstream FA
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
* check in the two ways of approaching backwards for softcapping, both functional
* prepare the softcap switch for backwards
* temporary
* cleanup to the way Tri prefers
* calculate dtanh when copying from scores -> dtanh Tensor
* no ternary operators allowed for constexpr, so just use some hack found online
* fix maybe_dtanh, restore some files
* restore another file
* move calculate_dtanh to utils and colocate with apply_softcap
* cleanup
* maybe last cleanup
* save for another pr
* remove a stray line
* fix spacing
* fix an issue, and make test_flash_attn.py ready to test softcapping backwards