* update ck
* update ck
* update ck again
* update ck
* use pointer as seed and offset
* update CK
* Remove useless "else"
* Fix page-attn block table read out-of-bound
---------
Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
* Integrate ck branch of ck_tile/fa_bwd_opt
* Assume dq and q share the same stride
* update ck
* Integrate more stride of dq_acc
* Revert fwd dropout
* Fix paremeter order
* Integrate ck with more stride
* update the limit of hdim of bwd
* Check argument
* Add test_flash_attn_causal
* Support unpad lse
* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic
* Fix stride and Kn0
* Fix CK sync issue
* Fix typo
* Update CK for changing of fmha_fwd_args
* Add kvcache tmp
* Add kvcache
* Fix comment
* Sync behavior with ck
* Update CK to develop
* remove large test case
* Add kvcache test
* Fix page_block_size in arg
* Minor fix
* Fix stride error
* Update seqlen of kvcache before splitkv
* Fix compile error
* Fix bug of hdim is not 8x
* Fit ck arg
* support adaptive num_splits
* add more tests
* Refine test tolerance
* update CK
* Move override_num_splits_if_necessary into cpp
* update ck
* Update ck
* Support different flag for different version of hip
* remove coerce-illegal, becasue this is not required in FA
* Update ck to fix xcratch memory
* Add coerce-illegal in some version
* Add compile flag for rtn rounding
* remove redundant init
* Using env var to switch rounding mode
* update ck
* Support ck in fmha
* Add ck submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up ck compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update ck to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py
* fix compile error
* Add mha_bwd
* Generate dropout seed and offset from user generator
* update CK
* Add mha_varlen_bwd
* Use same python as build flash-attn to generate ck kernel
* Fix bug of group mode fwd about returning softmax lse
* larger the test tollerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since we already implement mha_bwd
* Refine get value from tuple
* Use default parameter for stream_config
* unblock all platform
* Add comment
* refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size 64
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for rocm
* Detect rocm environment by pytorch's IS_HIP_EXTENSION
* update to lastest ck
* Add necessary compile flag
* Sync the api with upstream FA
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>