* Integrate the ck_tile/fa_bwd_opt branch of ck
* Assume dq and q share the same stride
* update ck
* Integrate more strides for dq_acc
* Revert fwd dropout
* Fix parameter order
* Integrate ck with more strides
* update the hdim limit for bwd
* Check argument
* Add test_flash_attn_causal
* Support unpadded lse
* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic
* Fix stride and Kn0
* Fix CK sync issue
* Fix typo
* Update CK for changing of fmha_fwd_args
* Add kvcache tmp
* Add kvcache
* Fix comment
* Sync behavior with ck
* Update CK to develop
* remove large test case
* Add kvcache test
* Fix page_block_size in arg
* Minor fix
* Fix stride error
* Update seqlen of kvcache before splitkv
* Fix compile error
* Fix bug when hdim is not a multiple of 8
* Fit ck arg
* support adaptive num_splits
* add more tests
* Refine test tolerance
* update CK
* Move override_num_splits_if_necessary into cpp
* update ck
* Update ck
* Support different flags for different versions of HIP
* remove coerce-illegal, because it is not required in FA
* Update ck to fix scratch memory
* Add coerce-illegal for some versions
* Add compile flag for rtn rounding
* remove redundant init
* Use env var to switch rounding mode
* update ck
* Support ck in fmha
* Add ck submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up ck compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update ck to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line; the C++ API will handle it.
Sync with test_flash_attn.py
* fix compile error
* Add mha_bwd
* Generate dropout seed and offset from user generator
* update CK
* Add mha_varlen_bwd
* Use the same Python that builds flash-attn to generate the ck kernels
* Fix bug in group mode fwd when returning softmax lse
* Increase the test tolerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since we already implemented mha_bwd
* Refine value retrieval from tuple
* Use default parameter for stream_config
* unblock all platforms
* Add comment
* refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size 64
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for rocm
* Detect ROCm environment via PyTorch's IS_HIP_EXTENSION
* update to latest ck
* Add necessary compile flag
* Sync the api with upstream FA
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
* 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention: (25 commits)
Install standard non-wheel package
Remove release creation
Build wheel on each push
Isolate 2.0.0 & cuda12
Clean setup.py imports
Remove builder project
Bump version
Add notes to github action workflow
Add torch dependency to final build
Exclude cuda erroring builds
Exclude additional disallowed matrix params
Full version matrix
Add CUDA 11.7
Release is actually unsupported
echo OS version
Temp disable deploy
OS version build numbers
Restore full build matrix
Refactor and clean of setup.py
Strip cuda name from torch version
...