flash-attention/csrc/flash_attn_ck/flash_common.cpp
rocking e2182cc21d
Support page kvcache in AMD ROCm (#1198)
* Integrate ck branch of ck_tile/fa_bwd_opt

* Assume dq and q share the same stride

* update ck

* Integrate more stride of dq_acc

* Revert fwd dropout

* Fix parameter order

* Integrate ck with more stride

* update the limit of hdim of bwd

* Check argument

* Add test_flash_attn_causal

* Support unpad lse

* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic

* Fix stride and Kn0

* Fix CK sync issue

* Fix typo

* Update CK for changing of fmha_fwd_args

* Add kvcache tmp

* Add kvcache

* Fix comment

* Sync behavior with ck

* Update CK to develop

* remove large test case

* Add kvcache test

* Fix page_block_size in arg

* Minor fix

* Fix stride error

* Update seqlen of kvcache before splitkv

* Fix compile error

* Fix bug when hdim is not a multiple of 8

* Fit ck arg

* support adaptive num_splits

* add more tests

* Refine test tolerance

* update CK

* Move override_num_splits_if_necessary into cpp

* update ck

* Update ck

* Support different flag for different version of hip

* Remove coerce-illegal, because it is not required in FA

* Update CK to fix scratch memory

* Add coerce-illegal in some version

* Add compile flag for rtn rounding

* remove redundant init

* Using env var to switch rounding mode

* update ck
2024-09-15 23:17:28 -07:00

35 lines
1.1 KiB
C++

/******************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/
#include "flash_common.hpp"
namespace flash {
/// @brief Choose the number of split-KV partitions for the CK forward kernel.
///
/// The heuristic is applied only when the caller requests an automatic choice
/// (`num_splits < 1`) and dropout is disabled (`p_drop == 0`). In every other
/// case, `num_splits` is returned unchanged.
///
/// @param batch         Batch size.
/// @param nhead         Number of query heads.
/// @param max_seqlen_q  Maximum query sequence length; divided into tiles of kM0 rows.
/// @param hdim_v        Head dimension of V; used as the N1 tile size (must be > 0).
/// @param p_drop        Dropout probability; the heuristic is skipped when non-zero.
/// @param num_splits    Requested split count; values < 1 request the heuristic.
/// @return The heuristic's choice when eligible. Otherwise `num_splits` unchanged,
///         including when the HIP device query fails or `hdim_v <= 0`.
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
    // Decide eligibility before touching the HIP runtime. When the heuristic is not
    // used, the device query below would be pure overhead
    // (hipGetDeviceProperties fills a large struct).
    if(num_splits >= 1 || p_drop != 0.0f)
        return num_splits;

    // kN1 is set to hdim_v below, so a zero head dimension would make the
    // num_n_blocks division divide by zero. Leave the request untouched instead.
    if(hdim_v <= 0)
        return num_splits;

    // Best effort: if the device cannot be queried, keep the caller's value.
    int device = 0;
    if(hipGetDevice(&device) != hipSuccess)
        return num_splits;
    hipDeviceProp_t props{};
    if(hipGetDeviceProperties(&props, device) != hipSuccess)
        return num_splits;

    // TODO - tile size should match the TileFmhaShape, hardcode for now
    const int kM0 = 128;
    const int kN1 = hdim_v;
    const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
    // Always 1 while kN1 == hdim_v; kept in ceil-div form for when kN1 becomes a real tile size.
    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;

    // NOTE(review): the `* 2` presumably assumes two workgroups resident per CU,
    // and the trailing 128 looks like the maximum split count. Confirm against
    // num_splits_heuristic_ck in flash_common.hpp.
    return num_splits_heuristic_ck(
        batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
}
} // namespace flash