Fix parallel split-k (#1116)
This commit is contained in:
parent
4082fed85a
commit
ff02da2667
@ -1168,15 +1168,30 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// A and B use the same alignment in generator.py
|
||||
int alignment = gemm_desc.A.alignment;
|
||||
// gemm operation for same compute capability and max operand alignment
|
||||
int alignment = std::max(
|
||||
gemm_desc.A.alignment,
|
||||
gemm_desc.B.alignment);
|
||||
|
||||
// gemm operation for same compute capability and iterator algorithm
|
||||
GemmPreferenceKey preference_key(
|
||||
gemm_desc.tile_description.minimum_compute_capability,
|
||||
alignment);
|
||||
|
||||
return find_gemm_operation(operators_it, preference_key);
|
||||
auto it = operators_it->second.find(preference_key);
|
||||
|
||||
if(it == operators_it->second.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// return matching gemm operation (same tile shape, stages, warp count, and instruction)
|
||||
for (auto op : it->second) {
|
||||
if (op->description().tile_description == operation->description().tile_description) {
|
||||
return op;
|
||||
}
|
||||
}
|
||||
|
||||
// return nullptr if no matching gemm operation found for parallel split-k reduction
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
Reference in New Issue
Block a user