fix split_k_mode and add reduction kernel for f16 input/accum/output (#896)

This commit is contained in:
Manish Gupta 2023-03-30 12:31:08 -07:00 committed by GitHub
parent bc36122c3f
commit 660a05f581
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 41 additions and 5 deletions

View File

@ -42,6 +42,7 @@ namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////
// CUTLASS Reduction Instances //
///////////////////////////////////////////////////////////////////////////////////////////////
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest);
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
//
void initialize_all_reduction_op(Manifest &manifest) {
initialize_reduce_add_linear_combination_f16_f16_f16(manifest);
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);

View File

@ -43,6 +43,40 @@ namespace library {
// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) {
using ElementWorkspace = cutlass::half_t;
using ElementAccumulator = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using ElementCompute = cutlass::half_t;
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
ElementAccumulator,
ElementCompute
>;
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
typename EpilogueOutputOp::ElementAccumulator,
EpilogueOutputOp::kCount
>;
using Operation_reduce_add_linear_combination_f16_f16_f16 = cutlass::reduction::device::ReduceSplitK<
cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>
>;
manifest.append(new ReductionOperation<
Operation_reduce_add_linear_combination_f16_f16_f16>(
"reduce_add_linear_combination_f16_f16_f16"
));
}
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
using ElementWorkspace = float;

View File

@ -62,7 +62,6 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
library::OperationKind::kGemm,
{
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
@ -71,6 +70,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
{ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"},
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
},
@ -298,8 +298,6 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "A", problem_space,
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
@ -313,6 +311,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "n", problem_space, n);
set_argument(result, "k", problem_space, k);
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "split_k_slices", problem_space, split_k_slices);
set_argument(result, "batch_count", problem_space, batch_count);

View File

@ -66,9 +66,8 @@ public:
/// Problem structure obtained from problem space
struct GemmProblem {
cutlass::library::GemmUniversalMode mode;
cutlass::library::SplitKMode split_k_mode;
int64_t m;
int64_t n;
int64_t k;
@ -77,6 +76,8 @@ public:
int64_t ldc;
std::vector<uint8_t> alpha;
std::vector<uint8_t> beta;
cutlass::library::SplitKMode split_k_mode;
int split_k_slices;
int batch_count;