flash-attention/tests
rocking e2182cc21d
Support page kvcache in AMD ROCm (#1198)
* Integrate ck branch of ck_tile/fa_bwd_opt

* Assume dq and q share the same stride

* update ck

* Integrate more stride of dq_acc

* Revert fwd dropout

* Fix parameter order

* Integrate ck with more stride

* update the limit of hdim of bwd

* Check argument

* Add test_flash_attn_causal

* Support unpad lse

* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic

* Fix stride and Kn0

* Fix CK sync issue

* Fix typo

* Update CK for changing of fmha_fwd_args

* Add kvcache tmp

* Add kvcache

* Fix comment

* Sync behavior with ck

* Update CK to develop

* remove large test case

* Add kvcache test

* Fix page_block_size in arg

* Minor fix

* Fix stride error

* Update seqlen of kvcache before splitkv

* Fix compile error

* Fix bug when hdim is not a multiple of 8

* Fit ck arg

* support adaptive num_splits

* add more tests

* Refine test tolerance

* update CK

* Move override_num_splits_if_necessary into cpp

* update ck

* Update ck

* Support different flag for different version of hip

* remove coerce-illegal, because this is not required in FA

* Update ck to fix scratch memory

* Add coerce-illegal in some version

* Add compile flag for rtn rounding

* remove redundant init

* Using env var to switch rounding mode

* update ck
2024-09-15 23:17:28 -07:00
layers Run isort and black on test files 2023-08-18 20:59:35 -07:00
losses [CrossEntropy] Support precomputed LSE 2024-09-08 09:24:43 -07:00
models Add test for BTLM init 2023-12-25 15:16:27 -08:00
modules Run isort and black on test files 2023-08-18 20:59:35 -07:00
ops [LayerNorm] Rename layernorm.py -> layer_norm.py 2024-01-05 00:21:03 -08:00
pyproject.toml Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 2023-08-25 15:05:28 -07:00
test_flash_attn_ck.py Support page kvcache in AMD ROCm (#1198) 2024-09-15 23:17:28 -07:00
test_flash_attn.py Fix test with alibi and cache_leftpad 2024-07-23 02:04:15 -07:00
test_rotary.py [Rotary] Add test for rotary when qkv are packed and there's GQA 2024-09-12 22:35:20 -07:00
test_util.py add missing if condition for key_padding_mask in test_util.py 2024-08-19 11:17:17 -07:00