* update ck
* update ck
* update ck again
* update ck
* use pointer as seed and offset
* update CK
* Remove useless "else"
* Fix page-attn block table read out-of-bound
---------
Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
* Add custom ops for compatibility with PT Compile (see the sketch after this block)
* Add support for varlen functions too
* Add version checks for pytorch API
* Fix PT compile interfaces so it works e2e
* Make sure PT < 2.4 runs fine
* Fix python mistake
* Fix all the autograd magic issues
* typo on head_dim
* Fix deterministic test failures, remove unneeded detaches()
* remove test requires_grad
* Resolve all the pytorch versioning issues
* C++ and python refactor to improve padding management for torch.compile()
* Add improvements suggested by @anijain2305
* Integrate ck branch of ck_tile/fa_bwd_opt
* Assume dq and q share the same stride
* update ck
* Integrate more stride of dq_acc
* Revert fwd dropout
* Fix parameter order
* Integrate ck with more stride
* update the hdim limit for bwd
* Check argument
* Add test_flash_attn_causal
* Support unpad lse
* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic
* Fix stride and Kn0
* Fix CK sync issue
* Fix typo
* Update CK for changing of fmha_fwd_args
* Add kvcache tmp
* Add kvcache
* Fix comment
* Sync behavior with ck
* Update CK to develop
* remove large test case
* Add kvcache test
* Fix page_block_size in arg
* Minor fix
* Fix stride error
* Update seqlen of kvcache before splitkv
* Fix compile error
* Fix bug when hdim is not a multiple of 8
* Fit ck arg
* support adaptive num_splits
* add more tests
* Refine test tolerance
* update CK
* Move override_num_splits_if_necessary into cpp
* update ck
* Update ck
* Support different flags for different versions of HIP
* Remove coerce-illegal, because it is not required in FA
* Update ck to fix scratch memory
* Add coerce-illegal for some versions
* Add compile flag for rtn rounding
* remove redundant init
* Use env var to switch rounding mode
* update ck
* Support ck in fmha
* Add ck submodule
* Do not return lse if return_softmax == false
* Use receipt to speed up ck compile time
* Integrate new version of ck_tile
* Support dropout for mha_fwd()
* Add dropout to mha_varlen_fwd()
* Update ck to develop
* Extract padding function for dropout randval
* Extract randval transformation function
* Sync the code structure and coding style with FA
* Remove this line; the C++ API will handle this.
Sync with test_flash_attn.py
* fix compile error
* Add mha_bwd
* Generate dropout seed and offset from user generator
* update CK
* Add mha_varlen_bwd
* Use the same Python that builds flash-attn to generate the CK kernels
* Fix bug in group-mode fwd when returning softmax lse
* Increase the test tolerance
* Add test_flash_attn_output() and test_flash_attn_varlen_output()
* Always fill softmax_lse
* Remove duplicate benchmark script, since we already implement mha_bwd
* Refine get value from tuple
* Use default parameter for stream_config
* Unblock all platforms
* Add comment
* refine the test code
* Refine naming
* Add unpack to namespace
* Do not hardcode the warp size 64
* Add more targets
* Add README
* Optimize mha_fwd if seqlen_q == 1
* Support get_wheel_url for rocm
* Detect rocm environment by pytorch's IS_HIP_EXTENSION
* update to latest ck
* Add necessary compile flag
* Sync the api with upstream FA
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
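The PT Compile entries above ("Add custom ops for compatibility with PT Compile", "Add version checks for pytorch API", "Make sure PT < 2.4 runs fine") amount to registering the extension's kernels as PyTorch custom ops so torch.compile can trace them without graph breaks. Below is a minimal sketch of that pattern, not the repository's actual registration code: the op namespace `myflash::fwd`, the placeholder einsum math, and the `hasattr` guard standing in for the real version checks are all assumptions.

```python
import torch

def _attn_eager(q, k, v, softmax_scale):
    # Placeholder math standing in for the call into the compiled extension.
    scores = torch.einsum("bshd,bthd->bhst", q, k) * softmax_scale
    return torch.einsum("bhst,bthd->bshd", scores.softmax(dim=-1), v)

if hasattr(torch.library, "custom_op"):  # PyTorch >= 2.4 custom-op API
    @torch.library.custom_op("myflash::fwd", mutates_args=())
    def my_flash_fwd(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                     softmax_scale: float) -> torch.Tensor:
        return _attn_eager(q, k, v, softmax_scale)

    @my_flash_fwd.register_fake
    def _(q, k, v, softmax_scale):
        # Shape/dtype-only "fake" kernel used while torch.compile traces the graph.
        return torch.empty_like(q)
else:
    my_flash_fwd = _attn_eager  # older PyTorch: plain eager fallback
```

Either way, `my_flash_fwd(q, k, v, scale)` is callable from eager code and from inside a torch.compile'd module.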
* check in the two approaches to the softcapping backward pass, both functional
* prepare the softcap switch for backwards
* temporary
* cleanup to the way Tri prefers
* calculate dtanh when copying from scores -> dtanh Tensor
* no ternary operators allowed with constexpr, so use a workaround found online
* fix maybe_dtanh, restore some files
* restore another file
* move calculate_dtanh to utils and colocate with apply_softcap (reference math sketched below)
* cleanup
* maybe last cleanup
* save for another pr
* remove a stray line
* fix spacing
* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
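For reference, the softcapping math behind the commits above: the forward pass applies softcap * tanh(scores / softcap), and the backward pass multiplies dS by the dtanh factor 1 - tanh(scores / softcap)^2. The Python below is only a reference sketch of that math; the kernel-side apply_softcap / calculate_dtanh helpers live in the CUDA utils, and the `_ref` names here are made up.

```python
import torch

def apply_softcap_ref(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Forward: squash raw attention scores into (-softcap, softcap).
    return softcap * torch.tanh(scores / softcap)

def calculate_dtanh_ref(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Backward factor: d/ds [softcap * tanh(s / softcap)] = 1 - tanh(s / softcap)^2.
    t = torch.tanh(scores / softcap)
    return 1.0 - t * t
```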
When the user passes `out` as a parameter and the call uses
parameters that trigger `seqlenq_ngroups_swapped`,
the CHECK_SHAPE on `out` is incorrect, since the shape of q has
already been modified (see the sketch below).
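An illustrative Python-level sketch of the intended check (the real one is the CHECK_SHAPE macro in the C++ mha_fwd; the helper name below is hypothetical): the user-supplied `out` has the caller-visible q shape, so it must be validated against that shape rather than against q after the internal reshape.

```python
def check_user_out(out, batch_size, seqlen_q, num_heads, head_dim):
    # `out` is allocated by the caller to match the original q layout, so check it
    # against that shape even after q has been reshaped for the
    # seqlenq_ngroups_swapped path.
    expected = (batch_size, seqlen_q, num_heads, head_dim)
    if tuple(out.shape) != expected:
        raise ValueError(f"out has shape {tuple(out.shape)}, expected {expected}")
```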
For faster and smaller builds in some simple cases,
provide switches to allow disabling the following
(see the sketch after this block):
- backward
- alibi
- uneven k
- dropout
- local attention
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
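A minimal setup.py-style sketch of how such switches are commonly wired, with an environment variable toggling a preprocessor define that compiles the feature out; the exact variable and macro names below are assumptions, not necessarily the ones the project ships.

```python
import os

def feature_disable_flags():
    # Each FLASH_ATTENTION_DISABLE_<FEATURE>=TRUE env var adds a -D define that the
    # kernel sources can use to skip instantiating that code path.
    flags = []
    for feature in ["BACKWARD", "ALIBI", "UNEVEN_K", "DROPOUT", "LOCAL"]:
        if os.getenv(f"FLASH_ATTENTION_DISABLE_{feature}", "FALSE") == "TRUE":
            flags.append(f"-DFLASHATTENTION_DISABLE_{feature}")
    return flags

# These flags would then be appended to the extension's compiler arguments in setup.py.
```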