Fx parallel split-k (#1116)
This commit is contained in:
parent
4082fed85a
commit
ff02da2667
@ -1168,15 +1168,30 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
// A and B uses the same alignment in the generator.py
|
// gemm operation for same compute capability and max operand alignment
|
||||||
int alignment = gemm_desc.A.alignment;
|
int alignment = std::max(
|
||||||
|
gemm_desc.A.alignment,
|
||||||
|
gemm_desc.B.alignment);
|
||||||
|
|
||||||
// gemm operation for same compute capability and iterator algorithm
|
|
||||||
GemmPreferenceKey preference_key(
|
GemmPreferenceKey preference_key(
|
||||||
gemm_desc.tile_description.minimum_compute_capability,
|
gemm_desc.tile_description.minimum_compute_capability,
|
||||||
alignment);
|
alignment);
|
||||||
|
|
||||||
return find_gemm_operation(operators_it, preference_key);
|
auto it = operators_it->second.find(preference_key);
|
||||||
|
|
||||||
|
if(it == operators_it->second.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// return matching gemm operation (same tile shape, stages, warp count, and instruction)
|
||||||
|
for (auto op : it->second) {
|
||||||
|
if (op->description().tile_description == operation->description().tile_description) {
|
||||||
|
return op;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// return nullptr if no matching gemm operation found for parallel split-k reduction
|
||||||
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
Loading…
Reference in New Issue
Block a user