fix split_k_mode and add reduction kernel for f16 input/accum/output (#896)
This commit is contained in:
parent
bc36122c3f
commit
660a05f581
@ -42,6 +42,7 @@ namespace library {
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// CUTLASS Reduction Instances //
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest);
|
||||
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
|
||||
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
|
||||
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
|
||||
@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
||||
//
|
||||
void initialize_all_reduction_op(Manifest &manifest) {
|
||||
|
||||
initialize_reduce_add_linear_combination_f16_f16_f16(manifest);
|
||||
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
|
||||
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
|
||||
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);
|
||||
|
@ -43,6 +43,40 @@ namespace library {
|
||||
|
||||
// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]
|
||||
|
||||
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) {
|
||||
|
||||
using ElementWorkspace = cutlass::half_t;
|
||||
using ElementAccumulator = cutlass::half_t;
|
||||
using ElementOutput = cutlass::half_t;
|
||||
using ElementCompute = cutlass::half_t;
|
||||
|
||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
|
||||
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
|
||||
ElementAccumulator,
|
||||
typename EpilogueOutputOp::ElementAccumulator,
|
||||
EpilogueOutputOp::kCount
|
||||
>;
|
||||
|
||||
using Operation_reduce_add_linear_combination_f16_f16_f16 = cutlass::reduction::device::ReduceSplitK<
|
||||
cutlass::reduction::kernel::ReduceSplitK<
|
||||
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
|
||||
EpilogueOutputOp,
|
||||
ReductionOp
|
||||
>
|
||||
>;
|
||||
|
||||
manifest.append(new ReductionOperation<
|
||||
Operation_reduce_add_linear_combination_f16_f16_f16>(
|
||||
"reduce_add_linear_combination_f16_f16_f16"
|
||||
));
|
||||
}
|
||||
|
||||
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
|
||||
|
||||
using ElementWorkspace = float;
|
||||
|
@ -62,7 +62,6 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
library::OperationKind::kGemm,
|
||||
{
|
||||
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
|
||||
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
|
||||
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
|
||||
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
|
||||
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
|
||||
@ -71,6 +70,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
{ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"},
|
||||
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
|
||||
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
|
||||
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
|
||||
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
|
||||
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
|
||||
},
|
||||
@ -298,8 +298,6 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
|
||||
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
|
||||
|
||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||
|
||||
set_argument(result, "A", problem_space,
|
||||
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
|
||||
|
||||
@ -313,6 +311,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
set_argument(result, "n", problem_space, n);
|
||||
set_argument(result, "k", problem_space, k);
|
||||
|
||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||
set_argument(result, "split_k_slices", problem_space, split_k_slices);
|
||||
set_argument(result, "batch_count", problem_space, batch_count);
|
||||
|
||||
|
@ -68,7 +68,6 @@ public:
|
||||
struct GemmProblem {
|
||||
|
||||
cutlass::library::GemmUniversalMode mode;
|
||||
cutlass::library::SplitKMode split_k_mode;
|
||||
int64_t m;
|
||||
int64_t n;
|
||||
int64_t k;
|
||||
@ -77,6 +76,8 @@ public:
|
||||
int64_t ldc;
|
||||
std::vector<uint8_t> alpha;
|
||||
std::vector<uint8_t> beta;
|
||||
|
||||
cutlass::library::SplitKMode split_k_mode;
|
||||
int split_k_slices;
|
||||
int batch_count;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user