* Integrate ck branch of ck_tile/fa_bwd_opt * Assume dq and q share the same stride * update ck * Integrate more stride of dq_acc * Revert fwd dropout * Fix parameter order * Integrate ck with more stride * update the limit of hdim of bwd * Check argument * Add test_flash_attn_causal * Support unpad lse * Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic * Fix stride and Kn0 * Fix CK sync issue * Fix typo * Update CK for changing of fmha_fwd_args * Add kvcache tmp * Add kvcache * Fix comment * Sync behavior with ck * Update CK to develop * remove large test case * Add kvcache test * Fix page_block_size in arg * Minor fix * Fix stride error * Update seqlen of kvcache before splitkv * Fix compile error * Fix bug when hdim is not a multiple of 8 * Fit ck arg * support adaptive num_splits * add more tests * Refine test tolerance * update CK * Move override_num_splits_if_necessary into cpp * update ck * Update ck * Support different flag for different version of hip * remove coerce-illegal, because this is not required in FA * Update ck to fix scratch memory * Add coerce-illegal in some version * Add compile flag for rtn rounding * remove redundant init * Using env var to switch rounding mode * update ck
35 lines
1.1 KiB
C++
35 lines
1.1 KiB
C++
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
namespace flash {
|
|
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
|
{
|
|
int device;
|
|
auto status = hipGetDevice(&device);
|
|
if(status != hipSuccess)
|
|
return num_splits;
|
|
|
|
hipDeviceProp_t props{};
|
|
status = hipGetDeviceProperties(&props, device);
|
|
if(status != hipSuccess)
|
|
return num_splits;
|
|
|
|
// TODO - tile size should match the TileFmhaShape, hardcode for now
|
|
const int kM0 = 128;
|
|
const int kN1 = hdim_v;
|
|
|
|
const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
|
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
|
|
|
|
if(num_splits < 1 && p_drop == 0.0f)
|
|
return num_splits_heuristic_ck(
|
|
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
|
|
|
return num_splits;
|
|
}
|
|
|
|
} // namespace flash
|