diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index bdea2f49..a24e4e03 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -1168,15 +1168,30 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope return nullptr; } - // A and B uses the same alignment in the generator.py - int alignment = gemm_desc.A.alignment; + // gemm operation for same compute capability and max operand alignment + int alignment = std::max( + gemm_desc.A.alignment, + gemm_desc.B.alignment); - // gemm operation for same compute capability and iterator algorithm GemmPreferenceKey preference_key( gemm_desc.tile_description.minimum_compute_capability, alignment); - return find_gemm_operation(operators_it, preference_key); + auto it = operators_it->second.find(preference_key); + + if(it == operators_it->second.end()) { + return nullptr; + } + + // return matching gemm operation (same tile shape, stages, warp count, and instruction) + for (auto op : it->second) { + if (op->description().tile_description == operation->description().tile_description) { + return op; + } + } + + // return nullptr if no matching gemm operation found for parallel split-k reduction + return nullptr; } /////////////////////////////////////////////////////////////////////////////////////////////////