fix split_k_mode and add reduction kernel for f16 input/accum/output (#896)
This commit is contained in:
parent
bc36122c3f
commit
660a05f581
@ -42,6 +42,7 @@ namespace library {
|
|||||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// CUTLASS Reduction Instances //
|
// CUTLASS Reduction Instances //
|
||||||
///////////////////////////////////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest);
|
||||||
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
|
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest);
|
||||||
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
|
void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest);
|
||||||
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
|
void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest);
|
||||||
@ -52,6 +53,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
|||||||
//
|
//
|
||||||
void initialize_all_reduction_op(Manifest &manifest) {
|
void initialize_all_reduction_op(Manifest &manifest) {
|
||||||
|
|
||||||
|
initialize_reduce_add_linear_combination_f16_f16_f16(manifest);
|
||||||
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
|
initialize_reduce_add_linear_combination_f32_f32_f16(manifest);
|
||||||
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
|
initialize_reduce_add_linear_combination_f32_f32_f32(manifest);
|
||||||
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);
|
initialize_reduce_add_linear_combination_f64_f64_f64(manifest);
|
||||||
|
@ -43,6 +43,40 @@ namespace library {
|
|||||||
|
|
||||||
// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]
|
// naming convention initialize_reduce_[ReductionOp]_[EpilogueOp]_[ElementWorkspace]_[ElementAccumulator]_[ElementOutput]
|
||||||
|
|
||||||
|
void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest) {
|
||||||
|
|
||||||
|
using ElementWorkspace = cutlass::half_t;
|
||||||
|
using ElementAccumulator = cutlass::half_t;
|
||||||
|
using ElementOutput = cutlass::half_t;
|
||||||
|
using ElementCompute = cutlass::half_t;
|
||||||
|
|
||||||
|
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||||
|
ElementOutput,
|
||||||
|
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||||
|
ElementAccumulator,
|
||||||
|
ElementCompute
|
||||||
|
>;
|
||||||
|
|
||||||
|
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
|
||||||
|
ElementAccumulator,
|
||||||
|
typename EpilogueOutputOp::ElementAccumulator,
|
||||||
|
EpilogueOutputOp::kCount
|
||||||
|
>;
|
||||||
|
|
||||||
|
using Operation_reduce_add_linear_combination_f16_f16_f16 = cutlass::reduction::device::ReduceSplitK<
|
||||||
|
cutlass::reduction::kernel::ReduceSplitK<
|
||||||
|
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
|
||||||
|
EpilogueOutputOp,
|
||||||
|
ReductionOp
|
||||||
|
>
|
||||||
|
>;
|
||||||
|
|
||||||
|
manifest.append(new ReductionOperation<
|
||||||
|
Operation_reduce_add_linear_combination_f16_f16_f16>(
|
||||||
|
"reduce_add_linear_combination_f16_f16_f16"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
|
void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
|
||||||
|
|
||||||
using ElementWorkspace = float;
|
using ElementWorkspace = float;
|
||||||
|
@ -62,7 +62,6 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
|||||||
library::OperationKind::kGemm,
|
library::OperationKind::kGemm,
|
||||||
{
|
{
|
||||||
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
|
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
|
||||||
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
|
|
||||||
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
|
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
|
||||||
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
|
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
|
||||||
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
|
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
|
||||||
@ -71,6 +70,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
|||||||
{ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"},
|
{ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"},
|
||||||
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
|
{ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"},
|
||||||
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
|
{ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"},
|
||||||
|
{ArgumentTypeID::kEnumerated, {"split_k_mode", "split-k-mode"}, "Variant of split K mode(serial, parallel)"},
|
||||||
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
|
{ArgumentTypeID::kInteger, {"split_k_slices", "split-k-slices"}, "Number of partitions of K dimension"},
|
||||||
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
|
{ArgumentTypeID::kInteger, {"batch_count", "batch-count"}, "Number of GEMMs computed in one batch"},
|
||||||
},
|
},
|
||||||
@ -298,8 +298,6 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
|||||||
|
|
||||||
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
|
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
|
||||||
|
|
||||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
|
||||||
|
|
||||||
set_argument(result, "A", problem_space,
|
set_argument(result, "A", problem_space,
|
||||||
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
|
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
|
||||||
|
|
||||||
@ -313,6 +311,7 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
|||||||
set_argument(result, "n", problem_space, n);
|
set_argument(result, "n", problem_space, n);
|
||||||
set_argument(result, "k", problem_space, k);
|
set_argument(result, "k", problem_space, k);
|
||||||
|
|
||||||
|
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||||
set_argument(result, "split_k_slices", problem_space, split_k_slices);
|
set_argument(result, "split_k_slices", problem_space, split_k_slices);
|
||||||
set_argument(result, "batch_count", problem_space, batch_count);
|
set_argument(result, "batch_count", problem_space, batch_count);
|
||||||
|
|
||||||
|
@ -66,9 +66,8 @@ public:
|
|||||||
|
|
||||||
/// Problem structure obtained from problem space
|
/// Problem structure obtained from problem space
|
||||||
struct GemmProblem {
|
struct GemmProblem {
|
||||||
|
|
||||||
cutlass::library::GemmUniversalMode mode;
|
cutlass::library::GemmUniversalMode mode;
|
||||||
cutlass::library::SplitKMode split_k_mode;
|
|
||||||
int64_t m;
|
int64_t m;
|
||||||
int64_t n;
|
int64_t n;
|
||||||
int64_t k;
|
int64_t k;
|
||||||
@ -77,6 +76,8 @@ public:
|
|||||||
int64_t ldc;
|
int64_t ldc;
|
||||||
std::vector<uint8_t> alpha;
|
std::vector<uint8_t> alpha;
|
||||||
std::vector<uint8_t> beta;
|
std::vector<uint8_t> beta;
|
||||||
|
|
||||||
|
cutlass::library::SplitKMode split_k_mode;
|
||||||
int split_k_slices;
|
int split_k_slices;
|
||||||
int batch_count;
|
int batch_count;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user