Support parallel split K mode for porfiling (#277)
* Support parallel split K mode for porfiling Signed-off-by: Peter Han <fujun.han@iluvatar.ai> * Parallel Split K support 1. find gemm kernel by preference key 2. switch m n for redution kernel Signed-off-by: Peter Han <fujun.han@iluvatar.ai> * parallel splitk for fp16 gemm * add one missing file Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
c3353add63
commit
1e4703cbab
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -338,6 +338,9 @@ using HandlePtr = std::unique_ptr<Handle>;
|
||||
/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace
|
||||
Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation);
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace
|
||||
Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation);
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
} // namespace cutlass
|
||||
|
@ -590,7 +590,8 @@ public:
|
||||
void const *configuration) const = 0;
|
||||
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration) const = 0;
|
||||
void const *configuration,
|
||||
void const *arguments = nullptr) const = 0;
|
||||
|
||||
virtual Status initialize(
|
||||
void const *configuration,
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -272,7 +272,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -266,7 +266,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -242,7 +242,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
@ -443,7 +444,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
@ -569,7 +571,7 @@ protected:
|
||||
operator_args.ldb = (configuration->ldb);
|
||||
operator_args.ldc = (configuration->ldc);
|
||||
operator_args.ldd = (configuration->ldd);
|
||||
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
@ -649,7 +651,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
@ -661,6 +664,14 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
status = update_arguments_(
|
||||
args,
|
||||
static_cast<GemmUniversalArguments const *>(arguments_ptr));
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t size = Operator::get_workspace_size(args);
|
||||
|
||||
return size;
|
||||
@ -855,7 +866,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
@ -1055,7 +1067,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -1098,6 +1098,59 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace
|
||||
Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) {
|
||||
|
||||
GemmDescription const &gemm_desc =
|
||||
static_cast<GemmDescription const &>(operation->description());
|
||||
|
||||
// if the curren gemm operation accumulator and output data type match return operation
|
||||
if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) {
|
||||
return operation;
|
||||
}
|
||||
|
||||
// find gemm operation to match gemm output and reduction workspace data type
|
||||
GemmFunctionalKey key(
|
||||
library::Provider::kCUTLASS,
|
||||
gemm_desc.gemm_kind,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||
gemm_desc.element_epilogue,
|
||||
gemm_desc.A.element,
|
||||
gemm_desc.A.layout,
|
||||
gemm_desc.transform_A,
|
||||
gemm_desc.B.element,
|
||||
gemm_desc.B.layout,
|
||||
gemm_desc.transform_B,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator);
|
||||
|
||||
// gemm operation table
|
||||
auto gemm_operations = Singleton::get().operation_table.gemm_operations;
|
||||
|
||||
// find ConvFunctionalKey in gemm operation table
|
||||
auto operators_it = gemm_operations.find(key);
|
||||
|
||||
if (operators_it == gemm_operations.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (operators_it->second.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// A and B uses the same alignment in the generator.py
|
||||
int alignment = gemm_desc.A.alignment;
|
||||
|
||||
// gemm operation for same compute capability and iterator algorithm
|
||||
GemmPreferenceKey preference_key(
|
||||
gemm_desc.tile_description.minimum_compute_capability,
|
||||
alignment);
|
||||
|
||||
return find_gemm_operation(operators_it, preference_key);
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace library
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -46,7 +46,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
|
||||
|
||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
@ -81,7 +81,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) {
|
||||
|
||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
@ -115,7 +115,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
||||
|
||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||
ElementOutput,
|
||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
||||
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||
ElementAccumulator,
|
||||
ElementCompute
|
||||
>;
|
||||
@ -140,6 +140,5 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -30,6 +30,7 @@
|
||||
#include <iostream>
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||
#include "cutlass/reduction/thread/reduction_operators.h"
|
||||
#include "cutlass/reduction/device/reduce_split_k.h"
|
||||
|
||||
@ -180,7 +181,8 @@ public:
|
||||
|
||||
/// Gets the device-side workspace
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration_ptr) const {
|
||||
void const *configuration_ptr,
|
||||
void const *arguments_ptr = nullptr) const {
|
||||
|
||||
OperatorArguments args;
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -403,7 +403,8 @@ public:
|
||||
}
|
||||
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration) const {
|
||||
void const *configuration,
|
||||
void const *arguments = nullptr) const {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -161,7 +161,8 @@ public:
|
||||
}
|
||||
|
||||
virtual uint64_t get_device_workspace_size(
|
||||
void const *configuration) const {
|
||||
void const *configuration,
|
||||
void const *arguments = nullptr) const {
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -37,6 +37,7 @@
|
||||
#include "gemm_operation_profiler.h"
|
||||
#include "gpu_timer.h"
|
||||
|
||||
#include "cutlass/library/singleton.h"
|
||||
#include "cutlass/library/library.h"
|
||||
#include "cutlass/library/handle.h"
|
||||
|
||||
@ -55,6 +56,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
||||
library::OperationKind::kGemm,
|
||||
{
|
||||
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
|
||||
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
|
||||
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
|
||||
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
|
||||
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
|
||||
@ -100,6 +102,9 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
|
||||
<< "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
|
||||
|
||||
<< "Profile a particular problem size with split K and paralell reduction:\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n"
|
||||
|
||||
<< "Using various input value distribution:\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
|
||||
<< " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n"
|
||||
@ -155,8 +160,17 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
// default value
|
||||
this->k = 1024;
|
||||
}
|
||||
|
||||
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
|
||||
// defualt value
|
||||
this->split_k_mode = library::SplitKMode::kSerial;
|
||||
}
|
||||
|
||||
this->mode = library::GemmUniversalMode::kGemm;
|
||||
if(this->split_k_mode == library::SplitKMode::kParallel) {
|
||||
this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
|
||||
}
|
||||
|
||||
if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) {
|
||||
// default value
|
||||
this->split_k_slices = 1;
|
||||
@ -165,8 +179,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
|
||||
// default value
|
||||
this->batch_count = 1;
|
||||
}
|
||||
else if (this->batch_count > 1) {
|
||||
} else if (this->batch_count > 1) {
|
||||
this->mode = library::GemmUniversalMode::kBatched;
|
||||
}
|
||||
|
||||
@ -210,7 +223,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
this->lda = DeviceAllocation::get_packed_layout(
|
||||
operation_desc.A.layout, {int(this->m), int(this->k)}).front();
|
||||
|
||||
@ -279,6 +292,8 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
||||
|
||||
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
|
||||
|
||||
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||
|
||||
set_argument(result, "A", problem_space,
|
||||
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
|
||||
|
||||
@ -321,7 +336,7 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
}
|
||||
|
||||
Status status = problem_.parse(operation_desc, problem_space, problem);
|
||||
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
@ -350,6 +365,13 @@ Status GemmOperationProfiler::initialize_configuration(
|
||||
gemm_workspace_.arguments.beta = problem_.beta.data();
|
||||
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
|
||||
// initialize reduction operation for parallel splitKMode
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
if (!initialize_reduction_configuration_(operation, problem)) {
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
||||
|
||||
return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
|
||||
@ -366,7 +388,7 @@ void GemmOperationProfiler::initialize_result_(
|
||||
result.disposition = Disposition::kNotRun;
|
||||
result.status = Status::kSuccess;
|
||||
result.operation_name = operation_desc.name;
|
||||
|
||||
|
||||
problem_.initialize_result(result, operation_desc, problem_space);
|
||||
|
||||
OperationProfiler::initialize_result_(result, operation_desc, problem_space);
|
||||
@ -377,6 +399,51 @@ void GemmOperationProfiler::initialize_result_(
|
||||
|
||||
}
|
||||
|
||||
/// Initialize redution problem dimentions and library::Operation
|
||||
bool GemmOperationProfiler::initialize_reduction_configuration_(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
library::GemmDescription const &gemm_desc =
|
||||
static_cast<library::GemmDescription const&>(operation->description());
|
||||
|
||||
if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
/// initialize library::ReductionConfiguration
|
||||
gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn();
|
||||
gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices);
|
||||
gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product();
|
||||
gemm_workspace_.reduction_configuration.ldw = problem_.ldc;
|
||||
gemm_workspace_.reduction_configuration.lds = problem_.ldc;
|
||||
gemm_workspace_.reduction_configuration.ldd = problem_.ldc;
|
||||
|
||||
// find reduction operation
|
||||
library::ReductionFunctionalKey reduction_key(
|
||||
library::Provider::kCUTLASS,
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace
|
||||
gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator
|
||||
gemm_desc.C.element, // element output
|
||||
gemm_desc.element_epilogue // element coumpute
|
||||
);
|
||||
|
||||
auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key);
|
||||
|
||||
if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// initialize reduction operation required for parallel split-k operator
|
||||
reduction_op_ = reduction_it->second;
|
||||
|
||||
// reduction operation found and initialized
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Initializes workspace
|
||||
Status GemmOperationProfiler::initialize_workspace(
|
||||
Options const &options,
|
||||
@ -385,7 +452,15 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace const &problem_space,
|
||||
ProblemSpace::Problem const &problem) {
|
||||
|
||||
|
||||
library::Operation const* underlying_operation = operation;
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
library::GemmDescription const &operation_desc =
|
||||
static_cast<library::GemmDescription const &>(operation->description());
|
||||
|
||||
@ -455,8 +530,12 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
);
|
||||
|
||||
gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data());
|
||||
}
|
||||
|
||||
gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride();
|
||||
gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride();
|
||||
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
||||
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
||||
}
|
||||
|
||||
//
|
||||
// Initialize the CUTLASS operation
|
||||
@ -467,16 +546,35 @@ Status GemmOperationProfiler::initialize_workspace(
|
||||
|
||||
if (options.execution_mode != ExecutionMode::kDryRun) {
|
||||
|
||||
uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration);
|
||||
uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration);
|
||||
gemm_workspace_.host_workspace.resize(workspace_size, 0);
|
||||
|
||||
workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration);
|
||||
workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration,
|
||||
&gemm_workspace_.arguments);
|
||||
gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
|
||||
|
||||
status = operation->initialize(
|
||||
status = underlying_operation->initialize(
|
||||
&gemm_workspace_.configuration,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
gemm_workspace_.device_workspace.data());
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration);
|
||||
gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0);
|
||||
|
||||
status = reduction_op_->initialize(
|
||||
&gemm_workspace_.reduction_configuration,
|
||||
gemm_workspace_.reduction_host_workspace.data(),
|
||||
nullptr);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -527,11 +625,34 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
||||
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
|
||||
gemm_workspace_.arguments.beta = problem_.beta_zero.data();
|
||||
|
||||
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
|
||||
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
|
||||
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
|
||||
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
|
||||
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
}
|
||||
|
||||
//
|
||||
// Run the CUTLASS operation
|
||||
//
|
||||
|
||||
results_.back().status = operation->run(
|
||||
// initialize gemm underlying operation to handle parallel reduction
|
||||
library::Operation const * underlying_operation = operation;
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
|
||||
results_.back().disposition = Disposition::kFailed;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
results_.back().status = underlying_operation->run(
|
||||
&gemm_workspace_.arguments,
|
||||
gemm_workspace_.host_workspace.data(),
|
||||
gemm_workspace_.device_workspace.data());
|
||||
@ -541,6 +662,19 @@ bool GemmOperationProfiler::verify_cutlass(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Run parallel reduction kernel for parallel split_k_mode
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
results_.back().status = reduction_op_->run(
|
||||
&gemm_workspace_.reduction_arguments,
|
||||
gemm_workspace_.reduction_host_workspace.data(),
|
||||
nullptr);
|
||||
|
||||
if (results_.back().status != Status::kSuccess) {
|
||||
results_.back().disposition = Disposition::kFailed;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
if (result != cudaSuccess) {
|
||||
results_.back().disposition = Disposition::kFailed;
|
||||
@ -896,6 +1030,19 @@ bool GemmOperationProfiler::profile(
|
||||
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
||||
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
|
||||
gemm_workspace_.arguments.beta = problem_.beta_zero.data();
|
||||
|
||||
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
|
||||
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
|
||||
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
|
||||
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
|
||||
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||
}
|
||||
|
||||
results_.back().status = profile_cutlass_(
|
||||
results_.back().runtime,
|
||||
options,
|
||||
@ -921,6 +1068,15 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
|
||||
GpuTimer timer;
|
||||
|
||||
// initialize gemm underlying operation to handle parallel reduction
|
||||
library::Operation const * underlying_operation = operation;
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
|
||||
return Status::kErrorNotSupported;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Optional sleep to limit power consumption and thermals
|
||||
//
|
||||
@ -942,8 +1098,16 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
|
||||
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||
|
||||
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
|
||||
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||
}
|
||||
|
||||
// Execute the CUTLASS operation
|
||||
status = operation->run(
|
||||
status = underlying_operation->run(
|
||||
&gemm_workspace_.arguments,
|
||||
host_workspace,
|
||||
device_workspace);
|
||||
@ -951,6 +1115,18 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Run parallel reduction kernel for parallel split_k_mode
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
status = reduction_op_->run(
|
||||
&gemm_workspace_.reduction_arguments,
|
||||
gemm_workspace_.reduction_host_workspace.data(),
|
||||
nullptr);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
@ -977,7 +1153,15 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
|
||||
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||
|
||||
status = operation->run(
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||
|
||||
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
|
||||
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||
}
|
||||
|
||||
status = underlying_operation->run(
|
||||
arguments,
|
||||
host_workspace,
|
||||
device_workspace);
|
||||
@ -985,6 +1169,18 @@ Status GemmOperationProfiler::profile_cutlass_(
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
|
||||
// Run parallel reduction kernel for parallel split_k_mode
|
||||
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||
status = reduction_op_->run(
|
||||
&gemm_workspace_.reduction_arguments,
|
||||
gemm_workspace_.reduction_host_workspace.data(),
|
||||
nullptr);
|
||||
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -1,5 +1,5 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||
* provided that the following conditions are met:
|
||||
@ -45,6 +45,7 @@
|
||||
#include "operation_profiler.h"
|
||||
#include "performance_result.h"
|
||||
#include "problem_space.h"
|
||||
#include "reduction_operation_profiler.h"
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -61,6 +62,7 @@ public:
|
||||
struct GemmProblem {
|
||||
|
||||
cutlass::library::GemmUniversalMode mode;
|
||||
cutlass::library::SplitKMode split_k_mode;
|
||||
int64_t m;
|
||||
int64_t n;
|
||||
int64_t k;
|
||||
@ -72,6 +74,12 @@ public:
|
||||
int split_k_slices;
|
||||
int batch_count;
|
||||
|
||||
// gemm with parallel interleaved reduction
|
||||
// gemm epilogue (alpha, beta) = (1.0, 0.0)
|
||||
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
|
||||
std::vector<uint8_t> alpha_one;
|
||||
std::vector<uint8_t> beta_zero;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -121,6 +129,13 @@ public:
|
||||
/// Buffer used for the operations' device workspace
|
||||
DeviceAllocation device_workspace;
|
||||
|
||||
/// Library configuration and arguments for reduction operator
|
||||
library::ReductionConfiguration reduction_configuration;
|
||||
library::ReductionArguments reduction_arguments;
|
||||
|
||||
/// Buffer used for the cutlass reduction operations' host workspace
|
||||
std::vector<uint8_t> reduction_host_workspace;
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
@ -141,6 +156,8 @@ protected:
|
||||
/// Device memory allocations
|
||||
GemmWorkspace gemm_workspace_;
|
||||
|
||||
/// CUTLASS parallel reduction operation to follow this* gemm operation
|
||||
library::Operation const *reduction_op_;
|
||||
|
||||
public:
|
||||
//
|
||||
@ -231,6 +248,10 @@ protected:
|
||||
void *host_workspace,
|
||||
void *device_workspace);
|
||||
|
||||
/// Initialize reduction problem dimensions and library::Operation
|
||||
bool initialize_reduction_configuration_(
|
||||
library::Operation const *operation,
|
||||
ProblemSpace::Problem const &problem);
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
Reference in New Issue
Block a user