Fx parallel split-k (#1116)

This commit is contained in:
Manish Gupta 2023-10-06 09:02:40 -07:00 committed by GitHub
parent 4082fed85a
commit ff02da2667
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1168,15 +1168,30 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope
return nullptr; return nullptr;
} }
// A and B uses the same alignment in the generator.py // gemm operation for same compute capability and max operand alignment
int alignment = gemm_desc.A.alignment; int alignment = std::max(
gemm_desc.A.alignment,
gemm_desc.B.alignment);
// gemm operation for same compute capability and iterator algorithm
GemmPreferenceKey preference_key( GemmPreferenceKey preference_key(
gemm_desc.tile_description.minimum_compute_capability, gemm_desc.tile_description.minimum_compute_capability,
alignment); alignment);
return find_gemm_operation(operators_it, preference_key); auto it = operators_it->second.find(preference_key);
if(it == operators_it->second.end()) {
return nullptr;
}
// return matching gemm operation (same tile shape, stages, warp count, and instruction)
for (auto op : it->second) {
if (op->description().tile_description == operation->description().tile_description) {
return op;
}
}
// return nullptr if no matching gemm operation found for parallel split-k reduction
return nullptr;
} }
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////