improve streamk load balance (#743)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Haicheng Wu 2022-12-25 13:56:33 -05:00 committed by GitHub
parent 78b30d3191
commit 1e64f153b3
2 changed files with 33 additions and 33 deletions

View File

@ -107,6 +107,8 @@ protected:
/// Kernel SM occupancy (in thread blocks)
thread_local static int sm_occupancy_;
+ /// Kernel dynamic shared memory allocation requirement
+ thread_local static int smem_size_;
/// Initialize static thread-local members for the thread's current device,
/// if necessary.
@ -138,15 +140,15 @@ protected:
}
// Update the kernel function's shared memory configuration for the current device
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
- if (smem_size >= (48 << 10))
- {
- // Requires more than 48KB: configure for extended, dynamic shared memory
+ smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
+ // If requires more than 48KB: configure for extended, dynamic shared memory
+ if (smem_size_ >= (48 << 10))
+ {
cudart_result = cudaFuncSetAttribute(
Kernel2<GemmKernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
- smem_size);
+ smem_size_);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
return Status::kErrorInternal;
@ -166,7 +168,7 @@ protected:
&sm_occupancy_,
Kernel2<GemmKernel>,
GemmKernel::kThreadCount,
- int(sizeof(typename GemmKernel::SharedStorage)),
+ smem_size_,
cudaOccupancyDisableCachingOverride);
if (cudart_result != cudaSuccess) {
CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
@ -179,7 +181,9 @@ protected:
CUTLASS_TRACE_HOST(" "
"device_ordinal: (" << device_ordinal_ << "), "
"device_sms: (" << device_sms_ << "), "
"sm_occupancy: (" << sm_occupancy_ << ")");
"sm_occupancy: (" << sm_occupancy_ << ") "
"smem_size: (" << smem_size_ << ") "
"GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
return Status::kSuccess;
}
@ -335,7 +339,6 @@ public:
CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
// Configure grid and block dimensions
- int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
dim3 block(GemmKernel::kThreadCount, 1, 1);
dim3 grid = params_.get_grid_dims();
@ -343,9 +346,9 @@ public:
CUTLASS_TRACE_HOST(" "
"grid: (" << grid << "), "
"block: (" << block << "), "
"SMEM: (" << smem_size << ")");
"SMEM: (" << smem_size_ << ")");
Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
// Query for errors
cudaError_t result = cudaGetLastError();
@ -398,6 +401,11 @@ thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
template <typename GemmKernel_>
thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
+ /// Kernel dynamic shared memory allocation requirement
+ template <typename GemmKernel_>
+ thread_local int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
FastDivmod sk_iters_per_big_block;
FastDivmod sk_iters_per_region;
FastDivmod sk_blocks_per_region;
- FastDivmod sm_occupancy;
} div_mod;
@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
", dp_blocks: " << dp_blocks <<
", sk_blocks_per_region: " << sk_blocks_per_region <<
", sk_regions: " << sk_regions <<
", sk_waves: " << sk_waves <<
", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@ -200,6 +199,7 @@ struct ThreadblockSwizzleStreamK {
", sm_occupancy: " << sm_occupancy <<
", avail_sms: " << avail_sms <<
", cohort_raster: " << cohort_raster <<
", num_blocks: " << get_num_blocks() <<
"\n\n";
#endif
}
@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {
// We're at (or greater) than GPU occupancy
- if (full_waves % sm_occupancy == sm_occupancy - 1)
+ if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
{
- // Form the SK wave from the partial wave to get us to full GPU occupancy
+ // If occupancy is more than one CTA per SM, form the SK wave from the partial
+ // wave to get us to full GPU occupancy
int max_sk_occupancy = 1;
dp_tiles = full_wave_tiles;
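A rough, self-contained illustration of the guard added here, with hypothetical numbers and assuming full_waves counts whole waves of output tiles across the available SMs: the trailing partial wave is only promoted to stream-K work when more than one CTA of the kernel fits per SM and the full waves stop one wave short of an even occupancy round.

#include <cstdio>

int main()
{
  // Hypothetical values, not taken from the diff
  int avail_sms    = 108;   // e.g. 108 SMs
  int sm_occupancy = 2;     // CTAs of this kernel resident per SM
  int output_tiles = 560;   // GEMM output tiles

  int full_waves = output_tiles / avail_sms;   // 5 full waves plus a 20-tile partial wave

  // New condition: only form the SK wave when occupancy is more than one CTA
  // per SM and the full waves end one wave short of an occupancy round
  bool form_sk_wave =
      (sm_occupancy > 1) &&
      (full_waves % sm_occupancy == sm_occupancy - 1);   // 5 % 2 == 1 -> true

  std::printf("full_waves: %d, form_sk_wave: %d\n", full_waves, form_sk_wave);
  return 0;
}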
@ -533,7 +534,6 @@ struct ThreadblockSwizzleStreamK {
dp_first_wave_tiles += waveset_excess;
dp_blocks -= (waveset_excess * avail_sms);
}
}
// Setup fast-div/mod for device-side usage
@ -541,7 +541,6 @@ struct ThreadblockSwizzleStreamK {
div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
div_mod.iters_per_tile = FastDivmod(iters_per_tile);
- div_mod.sm_occupancy = FastDivmod(sm_occupancy);
}
@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
/// Obtains number of threadblocks per GEMM
int get_num_blocks() const
{
- // int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
- // return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;
int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
- if (work_blocks < avail_sms)
+ if (work_blocks <= avail_sms * 2)
{
return work_blocks;
}
- int gpu_occupancy = sm_occupancy * avail_sms;
- int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
- return gpu_wavesets * gpu_occupancy;
+ return fast_max(work_blocks, avail_sms * 4);
}
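In short, the grid-sizing heuristic changes from "round large grids up to a whole number of occupancy wavesets" to "launch exactly what fits in two waves, otherwise launch at least four waves". A minimal sketch with hypothetical numbers, using std::max in place of cutlass::fast_max:

#include <algorithm>
#include <cstdio>

// Sketch of the revised heuristic; parameters mirror the members used above
int get_num_blocks(int sk_waves, int avail_sms, int dp_blocks, int reduction_blocks)
{
  int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;

  // Grids of at most two waves launch exactly the blocks they need
  if (work_blocks <= avail_sms * 2) {
    return work_blocks;
  }

  // Larger grids launch at least four waves, which is what the block-index
  // remapping in get_block_idx() below requires
  return std::max(work_blocks, avail_sms * 4);
}

int main()
{
  // 108 SMs, 1 SK wave, 250 DP blocks, 20 reduction blocks -> 378 work blocks;
  // 378 > 2*108, so the grid is padded up to max(378, 432) = 432 blocks
  std::printf("%d\n", get_num_blocks(1, 108, 250, 20));
  return 0;
}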
@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
CUTLASS_DEVICE
int get_block_idx() const
{
+ // Remap the block indices for the first two waves of thread blocks if
+ // we have multi-occupancy and the grid constitutes four or more waves
int block_idx = RematerializeBlockIdxX();
- int gpu_occupancy = avail_sms * sm_occupancy;
int num_blocks = device_num_blocks();
- int dest_sm, dest_wave;
- div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);
+ int dest_sm = block_idx / 2;
+ int dest_wave = block_idx % 2;
int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
- // remapping the first gpu_occupancy blocks
- if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
+ if ((sm_occupancy > 1) &&
+ (num_blocks >= avail_sms * 4) &&
+ (block_idx < avail_sms * 2))
{
block_idx = remapped_block_idx;
}
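The remapping itself is now a fixed two-way split: pairs of consecutive linear block indices are sent to the same SM slot in waves zero and one, replacing the div_mod.sm_occupancy-based mapping removed above, and it only applies when occupancy exceeds one CTA per SM and the grid is at least four waves deep. A standalone illustration of the arithmetic with a hypothetical SM count:

#include <cstdio>

int main()
{
  int avail_sms = 4;   // hypothetical small GPU

  // Remap the first two waves of linear block indices
  for (int block_idx = 0; block_idx < avail_sms * 2; ++block_idx) {
    int dest_sm   = block_idx / 2;   // consecutive pairs share an SM slot...
    int dest_wave = block_idx % 2;   // ...but land in different waves
    int remapped  = dest_sm + (dest_wave * avail_sms);
    std::printf("block %d -> %d (sm %d, wave %d)\n",
                block_idx, remapped, dest_sm, dest_wave);
  }
  // Prints: 0->0, 1->4, 2->1, 3->5, 4->2, 5->6, 6->3, 7->7
  return 0;
}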