diff --git a/examples/24_gemm_grouped/gemm_grouped.cu b/examples/24_gemm_grouped/gemm_grouped.cu index cfeb1ba5..a32c80d7 100644 --- a/examples/24_gemm_grouped/gemm_grouped.cu +++ b/examples/24_gemm_grouped/gemm_grouped.cu @@ -756,12 +756,6 @@ public: /// Returns the number of threadblocks to launch if the kernel can run on the target /// device. Otherwise, returns zero. int sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); - cudaDeviceProp properties; int device_idx; cudaError_t result = cudaGetDevice(&device_idx); @@ -776,9 +770,10 @@ public: throw std::runtime_error("cudaGetDeviceProperties() failed"); } - int occupancy = std::min(2, int(properties.sharedMemPerMultiprocessor / smem_size)); + int occupancy = std::max(0, Gemm::maximum_active_blocks()); /* maximum_active_blocks() returns -1 on failure; clamp so sufficient() reports zero threadblocks as documented instead of a negative count */ return properties.multiProcessorCount * occupancy; + } diff --git a/include/cutlass/gemm/device/gemm_grouped.h b/include/cutlass/gemm/device/gemm_grouped.h index f489ba93..628a56b0 100644 --- a/include/cutlass/gemm/device/gemm_grouped.h +++ b/include/cutlass/gemm/device/gemm_grouped.h @@ -139,70 +139,40 @@ public: CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); - int max_active_blocks = -1; int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - if (smem_size <= (48 << 10)) { + cudaError_t result; + if (smem_size > (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error " + << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); - if (result == cudaSuccess) 
{ - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } - else { - - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, - Kernel, - GemmKernel::kThreadCount, - 0); - - if (result != cudaSuccess) { - - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " - << cudaGetErrorString(result)); - - return -1; - } - - if (smem_capacity < 0) { - int device_idx = 0; - result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) { - return -1; - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) { - return -1; - } - - smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } - - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - - return occupancy; + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; } - CUTLASS_TRACE_HOST(" returning internal error"); - - return -1; + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; } /// Initializes GEMM state from arguments. diff --git a/test/unit/gemm/device/testbed_grouped.h b/test/unit/gemm/device/testbed_grouped.h index 29cffda3..2641e8d1 100644 --- a/test/unit/gemm/device/testbed_grouped.h +++ b/test/unit/gemm/device/testbed_grouped.h @@ -419,12 +419,6 @@ struct TestbedGrouped { /// Returns the number of threadblocks to launch if the kernel can run on the target /// device. Otherwise, returns zero. 
int sufficient() const { - // - // Determine SMEM requirements and waive if not satisfied - // - - int smem_size = int(sizeof(typename Gemm::GemmKernel::SharedStorage)); - cudaDeviceProp properties; int device_idx; cudaError_t result = cudaGetDevice(&device_idx); @@ -439,7 +433,7 @@ struct TestbedGrouped { throw std::runtime_error("cudaGetDeviceProperties() failed"); } - int occupancy = std::min(2, int(properties.sharedMemPerMultiprocessor / smem_size)); + int occupancy = std::max(0, Gemm::maximum_active_blocks()); /* maximum_active_blocks() returns -1 on failure; clamp so sufficient() reports zero threadblocks instead of a negative count */ return properties.multiProcessorCount * occupancy; }