improve streamk load balance (#743)

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>

commit 1e64f153b3 (parent 78b30d3191)
@@ -107,6 +107,8 @@ protected:
   /// Kernel SM occupancy (in thread blocks)
   thread_local static int sm_occupancy_;
 
+  /// Kernel dynamic shared memory allocation requirement
+  thread_local static int smem_size_;
 
   /// Initialize static thread-local members for the thread's current device,
   /// if necessary.
@@ -138,15 +140,15 @@ protected:
     }
 
     // Update the kernel function's shared memory configuration for the current device
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
-    if (smem_size >= (48 << 10))
-    {
-      // Requires more than 48KB: configure for extended, dynamic shared memory
-
+    smem_size_ = int(sizeof(typename GemmKernel::SharedStorage));
+
+    // If requires more than 48KB: configure for extended, dynamic shared memory
+    if (smem_size_ >= (48 << 10))
+    {
      cudart_result = cudaFuncSetAttribute(
        Kernel2<GemmKernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
-       smem_size);
+       smem_size_);
      if (cudart_result != cudaSuccess) {
        CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(cudart_result));
        return Status::kErrorInternal;
@@ -166,7 +168,7 @@ protected:
       &sm_occupancy_,
       Kernel2<GemmKernel>,
       GemmKernel::kThreadCount,
-      int(sizeof(typename GemmKernel::SharedStorage)),
+      smem_size_,
       cudaOccupancyDisableCachingOverride);
     if (cudart_result != cudaSuccess) {
       CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags() returned error " << cudaGetErrorString(cudart_result));
@@ -179,7 +181,9 @@ protected:
     CUTLASS_TRACE_HOST(" "
       "device_ordinal: (" << device_ordinal_ << "), "
       "device_sms: (" << device_sms_ << "), "
-      "sm_occupancy: (" << sm_occupancy_ << ")");
+      "sm_occupancy: (" << sm_occupancy_ << ") "
+      "smem_size: (" << smem_size_ << ") "
+      "GemmKernel::kThreadCount: (" << GemmKernel::kThreadCount << ")");
 
     return Status::kSuccess;
   }
@@ -335,7 +339,6 @@ public:
     CUTLASS_TRACE_HOST("GemmUniversalBase::run()");
 
     // Configure grid and block dimensions
-    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
     dim3 block(GemmKernel::kThreadCount, 1, 1);
     dim3 grid = params_.get_grid_dims();
 
@@ -343,9 +346,9 @@ public:
     CUTLASS_TRACE_HOST(" "
       "grid: (" << grid << "), "
       "block: (" << block << "), "
-      "SMEM: (" << smem_size << ")");
+      "SMEM: (" << smem_size_ << ")");
 
-    Kernel2<GemmKernel><<<grid, block, smem_size, stream>>>(params_);
+    Kernel2<GemmKernel><<<grid, block, smem_size_, stream>>>(params_);
 
     // Query for errors
     cudaError_t result = cudaGetLastError();
@@ -398,6 +401,11 @@ thread_local int GemmUniversalBase<GemmKernel_>::device_sms_ = -1;
 template <typename GemmKernel_>
 thread_local int GemmUniversalBase<GemmKernel_>::sm_occupancy_ = -1;
 
+/// Kernel dynamic shared memory allocation requirement
+template <typename GemmKernel_>
+thread_local int GemmUniversalBase<GemmKernel_>::smem_size_ = -1;
+
+
 
 /////////////////////////////////////////////////////////////////////////////////////////////////
 
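Note on the GemmUniversalBase hunks above: sizeof(GemmKernel::SharedStorage) is now computed once in the thread-local initializer and cached in smem_size_, which is then reused for the cudaFuncSetAttribute opt-in, the occupancy query, and the kernel launch in run(). Below is a minimal standalone sketch of that caching pattern; MyKernel, SharedStorage, init_device_props, and launch are illustrative stand-ins, not names from this commit.

#include <cuda_runtime.h>

struct SharedStorage { float buf[12288]; };      // 48 KB example payload

__global__ void MyKernel() {
  extern __shared__ char smem[];                 // dynamic shared memory
  (void)smem;
}

// Cached per thread, analogous to the thread_local statics in GemmUniversalBase
thread_local static int smem_size = -1;

cudaError_t init_device_props() {
  smem_size = int(sizeof(SharedStorage));

  // If requires more than 48KB: opt in to extended dynamic shared memory
  if (smem_size >= (48 << 10)) {
    cudaError_t result = cudaFuncSetAttribute(
        MyKernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    if (result != cudaSuccess) {
      return result;
    }
  }
  return cudaSuccess;
}

cudaError_t launch(dim3 grid, dim3 block, cudaStream_t stream) {
  // Reuse the cached size instead of recomputing sizeof(...) at launch time
  MyKernel<<<grid, block, smem_size, stream>>>();
  return cudaGetLastError();
}

int main() {
  if (init_device_props() != cudaSuccess) { return 1; }
  if (launch(dim3(1), dim3(128), 0) != cudaSuccess) { return 1; }
  return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}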
@@ -158,8 +158,6 @@ struct ThreadblockSwizzleStreamK {
     FastDivmod sk_iters_per_big_block;
     FastDivmod sk_iters_per_region;
     FastDivmod sk_blocks_per_region;
-    FastDivmod sm_occupancy;
-
   } div_mod;
 
 
@@ -188,6 +186,7 @@ struct ThreadblockSwizzleStreamK {
        ", dp_blocks: " << dp_blocks <<
        ", sk_blocks_per_region: " << sk_blocks_per_region <<
        ", sk_regions: " << sk_regions <<
+       ", sk_waves: " << sk_waves <<
        ", sk_iters_per_normal_block: " << sk_iters_per_normal_block <<
        ", sk_big_blocks_per_region: " << sk_big_blocks_per_region <<
        ", dp_first_wave_tiles: " << dp_first_wave_tiles <<
@@ -200,6 +199,7 @@ struct ThreadblockSwizzleStreamK {
        ", sm_occupancy: " << sm_occupancy <<
        ", avail_sms: " << avail_sms <<
        ", cohort_raster: " << cohort_raster <<
+       ", num_blocks: " << get_num_blocks() <<
        "\n\n";
 #endif
     }
@@ -316,9 +316,10 @@ struct ThreadblockSwizzleStreamK {
 
       // We're at (or greater) than GPU occupancy
 
-      if (full_waves % sm_occupancy == sm_occupancy - 1)
+      if ((sm_occupancy > 1 ) && (full_waves % sm_occupancy == sm_occupancy - 1))
       {
-        // Form the SK wave from the partial wave to get us to full GPU occupancy
+        // If occupancy is more than one CTA per SM, form the SK wave from the partial
+        // wave to get us to full GPU occupancy
        int max_sk_occupancy = 1;
 
        dp_tiles = full_wave_tiles;
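The added (sm_occupancy > 1) guard above means the trailing partial wave is only promoted into the stream-K wave when each SM can host more than one CTA; at single occupancy the old test (full_waves % sm_occupancy == sm_occupancy - 1) was trivially true. A small host-side illustration of when the branch now fires, with made-up wave counts:

#include <cstdio>

// Mirrors the updated condition that decides whether the partial wave
// is turned into a stream-K wave.
bool form_sk_wave_from_partial(int full_waves, int sm_occupancy) {
  return (sm_occupancy > 1) && (full_waves % sm_occupancy == sm_occupancy - 1);
}

int main() {
  std::printf("%d\n", form_sk_wave_from_partial(3, 1));  // 0: x % 1 == 0 used to pass unconditionally
  std::printf("%d\n", form_sk_wave_from_partial(3, 2));  // 1: odd number of full waves at occupancy 2
  std::printf("%d\n", form_sk_wave_from_partial(4, 2));  // 0: already a whole number of wavesets
  return 0;
}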
@@ -533,7 +534,6 @@ struct ThreadblockSwizzleStreamK {
        dp_first_wave_tiles += waveset_excess;
        dp_blocks -= (waveset_excess * avail_sms);
       }
-
     }
 
     // Setup fast-div/mod for device-side usage
@@ -541,7 +541,6 @@ struct ThreadblockSwizzleStreamK {
     div_mod.tiled_shape_n = FastDivmod(tiled_shape.n());
     div_mod.tiled_cohort_shape_n = FastDivmod(tiled_cohort_shape.n());
     div_mod.iters_per_tile = FastDivmod(iters_per_tile);
-    div_mod.sm_occupancy = FastDivmod(sm_occupancy);
   }
 
 
@@ -602,21 +601,14 @@ struct ThreadblockSwizzleStreamK {
   /// Obtains number of threadblocks per GEMM
   int get_num_blocks() const
   {
-    // int reduction_waves = (reduction_blocks + avail_sms - 1) / avail_sms;
-    // return ((sk_waves + reduction_waves) * avail_sms) + dp_blocks;
-
-
     int work_blocks = (sk_waves * avail_sms) + dp_blocks + reduction_blocks;
 
-    if (work_blocks < avail_sms)
+    if (work_blocks <= avail_sms * 2)
     {
       return work_blocks;
     }
 
-    int gpu_occupancy = sm_occupancy * avail_sms;
-    int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
-    return gpu_wavesets * gpu_occupancy;
-
+    return fast_max(work_blocks, avail_sms * 4);
   }
 
 
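get_num_blocks() above no longer rounds the launch up to a whole number of occupancy wavesets: grids of at most two waves launch exactly the blocks they need, and larger grids are padded only up to four waves of blocks, which lines up with the num_blocks >= avail_sms * 4 test used by the index remapping further down. A standalone before/after comparison of the arithmetic; avail_sms, sm_occupancy, and the work_blocks values are made-up examples, and fast_max is replaced by std::max:

#include <algorithm>
#include <cstdio>

// Updated sizing: launch exactly what is needed for small grids,
// otherwise at least four waves worth of blocks.
int num_blocks_new(int work_blocks, int avail_sms) {
  if (work_blocks <= avail_sms * 2) {
    return work_blocks;
  }
  return std::max(work_blocks, avail_sms * 4);
}

// Previous sizing: round up to a whole number of occupancy wavesets.
int num_blocks_old(int work_blocks, int avail_sms, int sm_occupancy) {
  if (work_blocks < avail_sms) {
    return work_blocks;
  }
  int gpu_occupancy = sm_occupancy * avail_sms;
  int gpu_wavesets = (work_blocks + gpu_occupancy - 1) / gpu_occupancy;
  return gpu_wavesets * gpu_occupancy;
}

int main() {
  int avail_sms = 108, sm_occupancy = 2;          // e.g. an A100-like SM count
  int samples[] = {100, 300, 500};
  for (int work_blocks : samples) {
    std::printf("work=%d  old=%d  new=%d\n",
                work_blocks,
                num_blocks_old(work_blocks, avail_sms, sm_occupancy),
                num_blocks_new(work_blocks, avail_sms));
  }
  return 0;
}
// work=100  old=100  new=100
// work=300  old=432  new=432
// work=500  old=648  new=500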
@@ -686,18 +678,18 @@ struct ThreadblockSwizzleStreamK {
   CUTLASS_DEVICE
   int get_block_idx() const
   {
-
+    // Remap the block indices for the first two waves of thread blocks if
+    // we have multi-occupancy and the grid constitutes four or more waves
     int block_idx = RematerializeBlockIdxX();
 
-    int gpu_occupancy = avail_sms * sm_occupancy;
     int num_blocks = device_num_blocks();
-    int dest_sm, dest_wave;
-    div_mod.sm_occupancy(dest_sm, dest_wave, block_idx);
-
+    int dest_sm = block_idx / 2;
+    int dest_wave = block_idx % 2;
     int remapped_block_idx = dest_sm + (dest_wave * avail_sms);
 
-    // remapping the first gpu_occupancy blocks
-    if ((block_idx < gpu_occupancy) && (num_blocks > gpu_occupancy))
+    if ((sm_occupancy > 1) &&
+        (num_blocks >= avail_sms * 4) &&
+        (block_idx < avail_sms * 2))
     {
       block_idx = remapped_block_idx;
     }
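The remapping in get_block_idx() now uses plain /2 and %2 arithmetic in place of the div_mod.sm_occupancy fast-divmod (which is why that FastDivmod member could be dropped earlier in this commit), and it is applied only when occupancy exceeds one CTA per SM and the grid spans at least four waves. Mechanically, hardware block indices 2k and 2k+1 become logical indices k and k + avail_sms, i.e. the same slot in two consecutive waves. A host-side sketch of the transform with toy sizes (4 SMs, 16 blocks; values chosen for illustration only):

#include <cstdio>

// Mirrors the updated remapping: consecutive hardware block indices 2k and 2k+1
// map to logical indices k and k + avail_sms.
int remap_block_idx(int block_idx, int avail_sms, int sm_occupancy, int num_blocks) {
  int dest_sm = block_idx / 2;
  int dest_wave = block_idx % 2;
  int remapped = dest_sm + (dest_wave * avail_sms);

  if ((sm_occupancy > 1) &&
      (num_blocks >= avail_sms * 4) &&
      (block_idx < avail_sms * 2)) {
    return remapped;
  }
  return block_idx;   // blocks beyond the first two waves keep their index
}

int main() {
  int avail_sms = 4, sm_occupancy = 2, num_blocks = 16;
  for (int b = 0; b < 8; ++b) {
    std::printf("%d -> %d\n", b, remap_block_idx(b, avail_sms, sm_occupancy, num_blocks));
  }
  // 0->0, 1->4, 2->1, 3->5, 4->2, 5->6, 6->3, 7->7
  return 0;
}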