Fix occupancy calculation for grouped GEMM (#532)
This commit is contained in:
parent
25e26a6e51
commit
fa56763c25
@ -756,12 +756,6 @@ public:
|
||||
/// Returns the number of threadblocks to launch if the kernel can run on the target
|
||||
/// device. Otherwise, returns zero.
|
||||
int sufficient() const {
|
||||
//
|
||||
// Determine SMEM requirements and waive if not satisfied
|
||||
//
|
||||
|
||||
int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
|
||||
|
||||
cudaDeviceProp properties;
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
@ -776,9 +770,10 @@ public:
|
||||
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
||||
}
|
||||
|
||||
int occupancy = std::min(2, int(properties.sharedMemPerMultiprocessor / smem_size));
|
||||
int occupancy = Gemm::maximum_active_blocks();
|
||||
|
||||
return properties.multiProcessorCount * occupancy;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -139,70 +139,40 @@ public:
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()");
|
||||
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
|
||||
|
||||
CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");
|
||||
|
||||
if (smem_size <= (48 << 10)) {
|
||||
cudaError_t result;
|
||||
if (smem_size > (48 << 10)) {
|
||||
result = cudaFuncSetAttribute(Kernel<GemmKernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
int max_active_blocks = -1;
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
Kernel<GemmKernel>,
|
||||
GemmKernel::kThreadCount,
|
||||
smem_size);
|
||||
|
||||
if (result == cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
||||
// Query assuming zero shared memory then compute occupancy limit based on SMEM
|
||||
cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
Kernel<GemmKernel>,
|
||||
GemmKernel::kThreadCount,
|
||||
0);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (smem_capacity < 0) {
|
||||
int device_idx = 0;
|
||||
result = cudaGetDevice(&device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
cudaDeviceProp properties;
|
||||
result = cudaGetDeviceProperties(&properties, device_idx);
|
||||
|
||||
if (result != cudaSuccess) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
|
||||
}
|
||||
|
||||
int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);
|
||||
|
||||
CUTLASS_TRACE_HOST(" occupancy: " << occupancy);
|
||||
|
||||
return occupancy;
|
||||
if (result != cudaSuccess) {
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" returning internal error");
|
||||
|
||||
return -1;
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
|
@ -419,12 +419,6 @@ struct TestbedGrouped {
|
||||
/// Returns the number of threadblocks to launch if the kernel can run on the target
|
||||
/// device. Otherwise, returns zero.
|
||||
int sufficient() const {
|
||||
//
|
||||
// Determine SMEM requirements and waive if not satisfied
|
||||
//
|
||||
|
||||
int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage));
|
||||
|
||||
cudaDeviceProp properties;
|
||||
int device_idx;
|
||||
cudaError_t result = cudaGetDevice(&device_idx);
|
||||
@ -439,7 +433,7 @@ struct TestbedGrouped {
|
||||
throw std::runtime_error("cudaGetDeviceProperties() failed");
|
||||
}
|
||||
|
||||
int occupancy = std::min(2, int(properties.sharedMemPerMultiprocessor / smem_size));
|
||||
int occupancy = Gemm::maximum_active_blocks();
|
||||
|
||||
return properties.multiProcessorCount * occupancy;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user