/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
|
||
|
|
|
||
|
|
namespace flash {
|
||
|
|
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||
|
|
{
|
||
|
|
int device;
|
||
|
|
auto status = hipGetDevice(&device);
|
||
|
|
if(status != hipSuccess)
|
||
|
|
return num_splits;
|
||
|
|
|
||
|
|
hipDeviceProp_t props{};
|
||
|
|
status = hipGetDeviceProperties(&props, device);
|
||
|
|
if(status != hipSuccess)
|
||
|
|
return num_splits;
|
||
|
|
|
||
|
|
// TODO - tile size should match the TileFmhaShape, hardcode for now
|
||
|
|
const int kM0 = 128;
|
||
|
|
const int kN1 = hdim_v;
|
||
|
|
|
||
|
|
const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||
|
|
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
|
||
|
|
|
||
|
|
if(num_splits < 1 && p_drop == 0.0f)
|
||
|
|
return num_splits_heuristic_ck(
|
||
|
|
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||
|
|
|
||
|
|
return num_splits;
|
||
|
|
}
|
||
|
|
|
||
|
|
} // namespace flash
|