From ff02da266713bd3365aed65c552412e126c040cb Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Fri, 6 Oct 2023 09:02:40 -0700 Subject: [PATCH] Fx parallel split-k (#1116) --- tools/library/src/handle.cu | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index bdea2f49..a24e4e03 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -1168,15 +1168,30 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope return nullptr; } - // A and B uses the same alignment in the generator.py - int alignment = gemm_desc.A.alignment; + // gemm operation for same compute capability and max operand alignment + int alignment = std::max( + gemm_desc.A.alignment, + gemm_desc.B.alignment); - // gemm operation for same compute capability and iterator algorithm GemmPreferenceKey preference_key( gemm_desc.tile_description.minimum_compute_capability, alignment); - return find_gemm_operation(operators_it, preference_key); + auto it = operators_it->second.find(preference_key); + + if(it == operators_it->second.end()) { + return nullptr; + } + + // return matching gemm operation (same tile shape, stages, warp count, and instruction) + for (auto op : it->second) { + if (op->description().tile_description == operation->description().tile_description) { + return op; + } + } + + // return nullptr if no matching gemm operation found for parallel split-k reduction + return nullptr; } /////////////////////////////////////////////////////////////////////////////////////////////////