improve streamk load balance (#743)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
parent 78b30d3191
commit 1e64f153b3
@@ -107,6 +107,8 @@ protected:
   /// Kernel SM occupancy (in thread blocks)
   thread_local static int sm_occupancy_;
 
+  /// Kernel dynamic shared memory allocation requirement
+  thread_local static int smem_size_;
+
   /// Initialize static thread-local members for the thread's current device,
   /// if necessary.
@@ -138,15 +140,15 @@ protected:
     }
 
     // Update the kernel function's shared memory configuration for the current device
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
-    if (smem_size >= (48 << 10))
-    {
-      // Requires more than 48KB: configure for extended, dynamic shared memory
+    smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
+
+    // If requires more than 48KB: configure for extended, dynamic shared memory
+    if (smem_size_ >= (48 << 10))
+    {
       cudart_result = cudaFuncSetAttribute(
         Kernel2<GemmKernel>,
         cudaFuncAttributeMaxDynamicSharedMemorySize,
-        smem_size);
+        smem_size_);
       if (cudart_result != cudaSuccess) {
         CUTLASS_TRACE_HOST("  cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
         return Status::kErrorInternal;
@@ -166,7 +168,7 @@ protected:
       &sm_occupancy_,
       Kernel2<GemmKernel>,
       GemmKernel::kThreadCount,
-      int(sizeof(typename GemmKernel::SharedStorage)),
+      smem_size_,
       cudaOccupancyDisableCachingOverride);
     if (cudart_result != cudaSuccess) {
       CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
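
Note: the two hunks above cache the kernel's dynamic shared memory requirement once per thread/device, opt in to the extended (>48 KB) limit when needed, and feed the same cached value into the occupancy query. A minimal standalone sketch of that pattern is below; MyKernel, init_device_props, smem_bytes, and threads_per_block are illustrative placeholders, not names from the patch.

    #include <cuda_runtime.h>

    __global__ void MyKernel() {}                  // stand-in for Kernel2<GemmKernel>

    static thread_local int smem_size = -1;        // cached dynamic smem requirement (bytes)
    static thread_local int sm_occupancy = -1;     // resident CTAs per SM for MyKernel

    cudaError_t init_device_props(int smem_bytes, int threads_per_block)
    {
      smem_size = smem_bytes;

      // Kernels needing more than 48KB of dynamic shared memory must opt in explicitly
      if (smem_size >= (48 << 10)) {
        cudaError_t result = cudaFuncSetAttribute(
            MyKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
        if (result != cudaSuccess) {
          return result;
        }
      }

      // Query occupancy with the same smem footprint the launch will actually use
      return cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(
          &sm_occupancy, MyKernel, threads_per_block, smem_size,
          cudaOccupancyDisableCachingOverride);
    }

Computing the size once and reusing it keeps the cudaFuncSetAttribute() limit, the occupancy estimate, and the eventual launch argument consistent with each other.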
@@ -179,7 +181,9 @@ protected:
     CUTLASS_TRACE_HOST("  "
       "device_ordinal: (" << device_ordinal_ << "), "
       "device_sms: (" << device_sms_ << "), "
-      "sm_occupancy: (" << sm_occupancy_ << ")");
+      "sm_occupancy: (" << sm_occupancy_ << ") "
+      "smem_size: (" << smem_size_ << ") "
+      "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
 
     return Status::kSuccess;
   }
@@ -335,7 +339,6 @@ public:
     CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
 
     // Configure grid and block dimensions
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
     dim3 block(GemmKernel::kThreadCount, 1, 1);
     dim3 grid = params_.get_grid_dims();
 
@@ -343,9 +346,9 @@ public:
     CUTLASS_TRACE_HOST("  "
       "grid: (" << grid << "), "
       "block: (" << block << "), "
-      "SMEM: (" << smem_size << ")");
+      "SMEM: (" << smem_size_ << ")");
 
-    Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+    Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
 
     // Query for errors
     cudaError_t result = cudaGetLastError();
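
Note: run() now passes the cached thread-local smem_size_ at launch instead of recomputing sizeof(SharedStorage) locally. A tiny standalone sketch of the launch-side half, reusing the illustrative MyKernel and smem_size names from the earlier sketch:

    #include <cuda_runtime.h>

    __global__ void MyKernel() {}           // stand-in for Kernel2<GemmKernel>
    static thread_local int smem_size = 0;  // cached by an init step like the one sketched above

    cudaError_t run(dim3 grid, dim3 block, cudaStream_t stream)
    {
      // The dynamic shared memory argument reuses the cached value, so the launch
      // always matches the footprint used for cudaFuncSetAttribute() and the occupancy query
      MyKernel<<<grid, block, smem_size, stream>>>();
      return cudaGetLastError();
    }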
@@ -398,6 +401,11 @@ thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
 template <typename GemmKernel_>
 thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
 
+/// Kernel dynamic shared memory allocation requirement
+template <typename GemmKernel_>
+thread_local int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
+
+
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
@@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
     FastDivmod sk_iters_per_big_block;
     FastDivmod sk_iters_per_region;
     FastDivmod sk_blocks_per_region;
-    FastDivmod sm_occupancy;
-
   } div_mod;
 
@@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
       ", dp_blocks: " << dp_blocks <<
       ", sk_blocks_per_region: " << sk_blocks_per_region <<
       ", sk_regions: " << sk_regions <<
      ", sk_waves: " << sk_waves <<
+      ", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
       ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
       ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@@ -200,6 +199,7 @@ struct ThreadblockSwizzleStreamK {
       ", sm_occupancy: " << sm_occupancy <<
       ", avail_sms: " << avail_sms <<
       ", cohort_raster: " << cohort_raster <<
+      ", num_blocks: " << get_num_blocks() <<
       "\n\n";
 #endif
   }
@@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {
 
       // We're at (or greater) than GPU occupancy
 
-      if (full_waves % sm_occupancy == sm_occupancy - 1)
+      if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
       {
-        // Form the SK wave from the partial wave to get us to full GPU occupancy
+        // If occupancy is more than one CTA per SM, form the SK wave from the partial
+        // wave to get us to full GPU occupancy
         int max_sk_occupancy = 1;
 
         dp_tiles = full_wave_tiles;
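
Note: a quick illustration of why the added sm_occupancy > 1 guard matters. When occupancy is one CTA per SM, full_waves % sm_occupancy is always 0 and sm_occupancy - 1 is also 0, so the old predicate fired for every shape even though there is no second CTA slot per SM for a stream-K wave to fill. The values below are made up; example() is just a throwaway illustration.

    void example()
    {
      int sm_occupancy = 1;   // one CTA per SM
      int full_waves   = 7;   // arbitrary

      // Old predicate: 7 % 1 == 0 and 1 - 1 == 0, i.e. true whenever occupancy is 1
      bool old_cond = (full_waves % sm_occupancy == sm_occupancy - 1);

      // New predicate: the extra guard restricts this branch to multi-occupancy kernels
      bool new_cond = (sm_occupancy > 1) && (full_waves % sm_occupancy == sm_occupancy - 1);

      (void)old_cond; (void)new_cond;
    }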
@@ -533,7 +534,6 @@ struct ThreadblockSwizzleStreamK {
         dp_first_wave_tiles += waveset_excess;
         dp_blocks -= (waveset_excess * avail_sms);
       }
-
     }
 
     // Setup fast-div/mod for device-side usage
@@ -541,7 +541,6 @@ struct ThreadblockSwizzleStreamK {
     div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
     div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
     div_mod.iters_per_tile = FastDivmod(iters_per_tile);
-    div_mod.sm_occupancy = FastDivmod(sm_occupancy);
   }
 
 
@@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
   /// Obtains number of threadblocks per GEMM
   int get_num_blocks() const
   {
-    // int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
-    // return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;
-
     int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
 
-    if (work_blocks < avail_sms)
+    if (work_blocks <= avail_sms * 2)
     {
       return work_blocks;
     }
 
-    int gpu_occupancy = sm_occupancy * avail_sms;
-    int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
-    return gpu_wavesets * gpu_occupancy;
-
+    return fast_max(work_blocks, avail_sms * 4);
   }
 
 
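
Note: the new sizing policy above launches exactly work_blocks CTAs for grids of at most two waves, and pads anything larger to at least four waves of CTAs instead of rounding up to a whole occupancy waveset; a grid of at least avail_sms * 4 blocks also satisfies the precondition of the block-index remapping further down. A rough standalone restatement (free function with explicit parameters instead of struct members, and fast_max replaced by a ternary so it compiles on its own):

    // Illustrative sketch only; the real member function reads these values from the struct
    int get_num_blocks(int sk_waves, int dp_blocks, int reduction_blocks, int avail_sms)
    {
      int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;

      // Small problems (two waves or fewer) launch exactly as many CTAs as there is work
      if (work_blocks <= avail_sms * 2) {
        return work_blocks;
      }

      // Larger problems launch at least four waves' worth of CTAs
      return (work_blocks > avail_sms * 4) ? work_blocks : (avail_sms * 4);
    }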
@@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
   CUTLASS_DEVICE
   int get_block_idx() const
   {
+    // Remap the block indices for the first two waves of thread blocks if
+    // we have multi-occupancy and the grid constitutes four or more waves
+
     int block_idx = RematerializeBlockIdxX();
-
-    int gpu_occupancy = avail_sms * sm_occupancy;
     int num_blocks = device_num_blocks();
-    int dest_sm, dest_wave;
-
-    div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);
-
+    int dest_sm = block_idx / 2;
+    int dest_wave = block_idx % 2;
     int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
 
-    // remapping the first gpu_occupancy blocks
-    if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
+    if ((sm_occupancy > 1) &&
+        (num_blocks >= avail_sms * 4) &&
+        (block_idx < avail_sms * 2))
     {
       block_idx = remapped_block_idx;
     }
 
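
Note: the hunk above replaces the fast-divmod by sm_occupancy with a fixed pairwise swizzle of the first two waves, so consecutive block indices land on the same SM. A standalone sketch of the arithmetic (remap_first_two_waves is an illustrative free function, not the actual member):

    // Pairs up consecutive block indices onto the same SM across the first two waves
    int remap_first_two_waves(int block_idx, int num_blocks, int avail_sms, int sm_occupancy)
    {
      // Only applies with multi-occupancy, a grid of at least four waves, and
      // only to blocks in the first two waves
      if ((sm_occupancy > 1) && (num_blocks >= avail_sms * 4) && (block_idx < avail_sms * 2))
      {
        int dest_sm   = block_idx / 2;  // blocks {0,1} -> SM 0, {2,3} -> SM 1, ...
        int dest_wave = block_idx % 2;  // even blocks land in wave 0, odd blocks in wave 1
        return dest_sm + (dest_wave * avail_sms);
      }
      return block_idx;
    }

For example, with avail_sms = 4 (and a grid of at least 16 blocks), indices 0..7 remap to 0, 4, 1, 5, 2, 6, 3, 7: blocks 0 and 1 occupy the same SM, blocks 2 and 3 the next SM, and so on.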