Support parallel split K mode for profiling (#277)

* Support parallel split K mode for profiling

Signed-off-by: Peter Han <fujun.han@iluvatar.ai>

* Parallel Split K support

  1. find gemm kernel by preference key
  2. swap m and n for the reduction kernel

Signed-off-by: Peter Han <fujun.han@iluvatar.ai>

* parallel split-K for fp16 gemm

* add one missing file

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
Fujun Han 2022-01-27 23:37:37 +08:00 committed by GitHub
parent c3353add63
commit 1e4703cbab
13 changed files with 332 additions and 40 deletions
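
For orientation, the profiled mode takes two kernel launches: the GEMM writes one partial M x N product per split-K slice into a device workspace with epilogue scalars (alpha, beta) = (1, 0), and a reduction kernel then folds the slices together and applies the user's (alpha, beta) against the original C. A minimal host-side sketch of that data flow (plain C++ with illustrative names, all float for brevity; not the CUTLASS API — in the profiler the workspace holds accumulator-type partials):

#include <cstdint>
#include <vector>

// Reduce `partitions` partial M x N products (one per split-K slice) into
// D = alpha * sum(partials) + beta * C.
void reduce_split_k(std::vector<float> &D,                // m * n output
                    std::vector<float> const &C,          // m * n source
                    std::vector<float> const &workspace,  // partitions * m * n
                    int m, int n, int partitions,
                    float alpha, float beta) {
  int64_t partition_stride = int64_t(m) * n;  // one full M x N tile per slice
  for (int64_t idx = 0; idx < partition_stride; ++idx) {
    float sum = 0.0f;
    for (int p = 0; p < partitions; ++p) {
      sum += workspace[p * partition_stride + idx];
    }
    D[idx] = alpha * sum + beta * C[idx];     // reduction epilogue
  }
}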

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -338,6 +338,9 @@ using HandlePtr = std::unique_ptr<Handle>;
/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace
Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation);
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace
Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation);
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass

View File

@ -590,7 +590,8 @@ public:
void const *configuration) const = 0;
virtual uint64_t get_device_workspace_size(
void const *configuration) const = 0;
void const *configuration,
void const *arguments = nullptr) const = 0;
virtual Status initialize(
void const *configuration,
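
The new optional parameter exists because, for parallel split-K, the device workspace holds the partial products, and its size depends on the runtime arguments (notably the slice count) rather than on the configuration alone. Roughly (an illustrative formula, not the exact CUTLASS computation):

    workspace_bytes ≈ m * n * sizeof(ElementAccumulator) * split_k_slices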

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -272,7 +272,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -266,7 +266,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -242,7 +242,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;
@ -443,7 +444,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;
@ -569,7 +571,7 @@ protected:
operator_args.ldb = (configuration->ldb);
operator_args.ldc = (configuration->ldc);
operator_args.ldd = (configuration->ldd);
return Status::kSuccess;
}
@ -649,7 +651,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr) const {
OperatorArguments args;
@ -661,6 +664,14 @@ public:
return 0;
}
status = update_arguments_(
args,
static_cast<GemmUniversalArguments const *>(arguments_ptr));
if (status != Status::kSuccess) {
return 0;
}
uint64_t size = Operator::get_workspace_size(args);
return size;
@ -855,7 +866,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;
@ -1055,7 +1067,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -1098,6 +1098,59 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope
return nullptr;
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace
Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) {
GemmDescription const &gemm_desc =
static_cast<GemmDescription const &>(operation->description());
// if the current gemm operation's accumulator and output data types match, return the operation
if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) {
return operation;
}
// find gemm operation to match gemm output and reduction workspace data type
GemmFunctionalKey key(
library::Provider::kCUTLASS,
gemm_desc.gemm_kind,
gemm_desc.tile_description.math_instruction.element_accumulator,
gemm_desc.element_epilogue,
gemm_desc.A.element,
gemm_desc.A.layout,
gemm_desc.transform_A,
gemm_desc.B.element,
gemm_desc.B.layout,
gemm_desc.transform_B,
gemm_desc.tile_description.math_instruction.element_accumulator);
// gemm operation table
auto gemm_operations = Singleton::get().operation_table.gemm_operations;
// find GemmFunctionalKey in gemm operation table
auto operators_it = gemm_operations.find(key);
if (operators_it == gemm_operations.end()) {
return nullptr;
}
if (operators_it->second.empty()) {
return nullptr;
}
// A and B use the same alignment in generator.py
int alignment = gemm_desc.A.alignment;
// gemm operation for same compute capability and alignment
GemmPreferenceKey preference_key(
gemm_desc.tile_description.minimum_compute_capability,
alignment);
return find_gemm_operation(operators_it, preference_key);
}
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
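
A minimal caller-side sketch of the new lookup (mirroring how the profiler uses it later in this commit; error handling elided):

#include "cutlass/library/handle.h"
#include "cutlass/library/library.h"

// Resolve the kernel that actually runs when split-K mode is parallel.
cutlass::library::Operation const *resolve_underlying_operation(
    cutlass::library::Operation const *operation,
    cutlass::library::SplitKMode split_k_mode) {
  using cutlass::library::SplitKMode;
  if (split_k_mode != SplitKMode::kParallel) {
    return operation;  // serial split-K runs the selected kernel directly
  }
  // Returns `operation` itself when ElementC already equals the accumulator
  // type; otherwise finds a variant whose output matches the reduction
  // workspace, or nullptr when no such kernel was generated.
  return cutlass::library::find_gemm_operation_for_parallel_reduction(operation);
}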

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -46,7 +46,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
ElementAccumulator,
ElementCompute
>;
@ -81,7 +81,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) {
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
ElementAccumulator,
ElementCompute
>;
@ -115,7 +115,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
ElementOutput,
128 / cutlass::sizeof_bits<ElementOutput>::value,
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
ElementAccumulator,
ElementCompute
>;
@ -140,6 +140,5 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
));
}
}
}
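
The count change above presumably keeps the epilogue's vector width aligned to the workspace the reduction streams from: with an f16 output but an f32 workspace, deriving the count from ElementOutput gave 8-wide (256-bit) workspace accesses, while deriving it from ElementWorkspace keeps them at 128 bits. For the f32 and cf32 reductions the two types coincide, so only the f16-output case changes numerically. A sketch of the arithmetic (an assumption about the rationale, not code from this commit):

#include "cutlass/numeric_types.h"

static int const kCountFromOutput =
    128 / cutlass::sizeof_bits<cutlass::half_t>::value;  // 8 elements/access
static int const kCountFromWorkspace =
    128 / cutlass::sizeof_bits<float>::value;            // 4 elements/access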

View File

@ -30,6 +30,7 @@
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#include "cutlass/reduction/device/reduce_split_k.h"
@ -180,7 +181,8 @@ public:
/// Gets the device-side workspace
virtual uint64_t get_device_workspace_size(
void const *configuration_ptr) const {
void const *configuration_ptr,
void const *arguments_ptr = nullptr) const {
OperatorArguments args;

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -403,7 +403,8 @@ public:
}
virtual uint64_t get_device_workspace_size(
void const *configuration) const {
void const *configuration,
void const *arguments = nullptr) const {
return 0;
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -161,7 +161,8 @@ public:
}
virtual uint64_t get_device_workspace_size(
void const *configuration) const {
void const *configuration,
void const *arguments = nullptr) const {
return 0;
}

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -37,6 +37,7 @@
#include "gemm_operation_profiler.h"
#include "gpu_timer.h"
#include "cutlass/library/singleton.h"
#include "cutlass/library/library.h"
#include "cutlass/library/handle.h"
@ -55,6 +56,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
library::OperationKind::kGemm,
{
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
@ -100,6 +102,9 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
<< "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
<< " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
<< "Profile a particular problem size with split K and paralell reduction:\n"
<< " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n"
<< "Using various input value distribution:\n"
<< " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
<< " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n"
@ -155,8 +160,17 @@ Status GemmOperationProfiler::GemmProblem::parse(
// default value
this->k = 1024;
}
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
// default value
this->split_k_mode = library::SplitKMode::kSerial;
}
this->mode = library::GemmUniversalMode::kGemm;
if(this->split_k_mode == library::SplitKMode::kParallel) {
this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
}
if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) {
// default value
this->split_k_slices = 1;
@ -165,8 +179,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
// default value
this->batch_count = 1;
}
else if (this->batch_count > 1) {
} else if (this->batch_count > 1) {
this->mode = library::GemmUniversalMode::kBatched;
}
@ -210,7 +223,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
return Status::kErrorInternal;
}
}
this->lda = DeviceAllocation::get_packed_layout(
operation_desc.A.layout, {int(this->m), int(this->k)}).front();
@ -279,6 +292,8 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
set_argument(result, "A", problem_space,
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
@ -321,7 +336,7 @@ Status GemmOperationProfiler::initialize_configuration(
}
Status status = problem_.parse(operation_desc, problem_space, problem);
if (status != Status::kSuccess) {
return status;
}
@ -350,6 +365,13 @@ Status GemmOperationProfiler::initialize_configuration(
gemm_workspace_.arguments.beta = problem_.beta.data();
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
// initialize reduction operation for parallel split-K mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
if (!initialize_reduction_configuration_(operation, problem)) {
return Status::kErrorInternal;
}
}
initialize_result_(this->model_result_, options, operation_desc, problem_space);
return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
@ -366,7 +388,7 @@ void GemmOperationProfiler::initialize_result_(
result.disposition = Disposition::kNotRun;
result.status = Status::kSuccess;
result.operation_name = operation_desc.name;
problem_.initialize_result(result, operation_desc, problem_space);
OperationProfiler::initialize_result_(result, operation_desc, problem_space);
@ -377,6 +399,51 @@ void GemmOperationProfiler::initialize_result_(
}
/// Initialize reduction problem dimensions and library::Operation
bool GemmOperationProfiler::initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem) {
library::GemmDescription const &gemm_desc =
static_cast<library::GemmDescription const&>(operation->description());
if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) {
return false;
}
if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) {
return false;
}
/// initialize library::ReductionConfiguration
gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn();
gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices);
gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product();
gemm_workspace_.reduction_configuration.ldw = problem_.ldc;
gemm_workspace_.reduction_configuration.lds = problem_.ldc;
gemm_workspace_.reduction_configuration.ldd = problem_.ldc;
// find reduction operation
library::ReductionFunctionalKey reduction_key(
library::Provider::kCUTLASS,
gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace
gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator
gemm_desc.C.element, // element output
gemm_desc.element_epilogue // element compute
);
auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key);
if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) {
return false;
}
// initialize reduction operation required for parallel split-k operator
reduction_op_ = reduction_it->second;
// reduction operation found and initialized
return true;
}
/// Initializes workspace
Status GemmOperationProfiler::initialize_workspace(
Options const &options,
@ -385,7 +452,15 @@ Status GemmOperationProfiler::initialize_workspace(
library::Operation const *operation,
ProblemSpace const &problem_space,
ProblemSpace::Problem const &problem) {
library::Operation const* underlying_operation = operation;
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
return Status::kErrorNotSupported;
}
}
library::GemmDescription const &operation_desc =
static_cast<library::GemmDescription const &>(operation->description());
@ -455,8 +530,12 @@ Status GemmOperationProfiler::initialize_workspace(
);
gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data());
}
gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride();
gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride();
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
}
//
// Initialize the CUTLASS operation
@ -467,16 +546,35 @@ Status GemmOperationProfiler::initialize_workspace(
if (options.execution_mode != ExecutionMode::kDryRun) {
uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration);
uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration);
gemm_workspace_.host_workspace.resize(workspace_size, 0);
workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration);
workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration,
&gemm_workspace_.arguments);
gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
status = operation->initialize(
status = underlying_operation->initialize(
&gemm_workspace_.configuration,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());
if (status != Status::kSuccess) {
return status;
}
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration);
gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0);
status = reduction_op_->initialize(
&gemm_workspace_.reduction_configuration,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);
if (status != Status::kSuccess) {
return status;
}
}
}
//
@ -527,11 +625,34 @@ bool GemmOperationProfiler::verify_cutlass(
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
gemm_workspace_.arguments.beta = problem_.beta_zero.data();
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
}
//
// Run the CUTLASS operation
//
results_.back().status = operation->run(
// initialize underlying gemm operation to handle parallel reduction
library::Operation const * underlying_operation = operation;
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
results_.back().disposition = Disposition::kFailed;
return false;
}
}
results_.back().status = underlying_operation->run(
&gemm_workspace_.arguments,
gemm_workspace_.host_workspace.data(),
gemm_workspace_.device_workspace.data());
@ -541,6 +662,19 @@ bool GemmOperationProfiler::verify_cutlass(
return false;
}
// Run parallel reduction kernel for parallel split_k_mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
results_.back().status = reduction_op_->run(
&gemm_workspace_.reduction_arguments,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);
if (results_.back().status != Status::kSuccess) {
results_.back().disposition = Disposition::kFailed;
return false;
}
}
cudaError_t result = cudaDeviceSynchronize();
if (result != cudaSuccess) {
results_.back().disposition = Disposition::kFailed;
@ -896,6 +1030,19 @@ bool GemmOperationProfiler::profile(
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
gemm_workspace_.arguments.beta = problem_.beta_zero.data();
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
}
results_.back().status = profile_cutlass_(
results_.back().runtime,
options,
@ -921,6 +1068,15 @@ Status GemmOperationProfiler::profile_cutlass_(
GpuTimer timer;
// initialize underlying gemm operation to handle parallel reduction
library::Operation const * underlying_operation = operation;
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
return Status::kErrorNotSupported;
}
}
//
// Optional sleep to limit power consumption and thermals
//
@ -942,8 +1098,16 @@ Status GemmOperationProfiler::profile_cutlass_(
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
}
// Execute the CUTLASS operation
status = operation->run(
status = underlying_operation->run(
&gemm_workspace_.arguments,
host_workspace,
device_workspace);
@ -951,6 +1115,18 @@ Status GemmOperationProfiler::profile_cutlass_(
if (status != Status::kSuccess) {
return status;
}
// Run parallel reduction kernel for parallel split_k_mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
status = reduction_op_->run(
&gemm_workspace_.reduction_arguments,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);
if (status != Status::kSuccess) {
return status;
}
}
}
//
@ -977,7 +1153,15 @@ Status GemmOperationProfiler::profile_cutlass_(
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
status = operation->run(
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
}
status = underlying_operation->run(
arguments,
host_workspace,
device_workspace);
@ -985,6 +1169,18 @@ Status GemmOperationProfiler::profile_cutlass_(
if (status != Status::kSuccess) {
return status;
}
// Run parallel reduction kernel for parallel split_k_mode
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
status = reduction_op_->run(
&gemm_workspace_.reduction_arguments,
gemm_workspace_.reduction_host_workspace.data(),
nullptr);
if (status != Status::kSuccess) {
return status;
}
}
}
//
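
One subtlety in initialize_reduction_configuration_ above is the swapped extents, GemmCoord(n, m, k).mn(). A plausible reading (an assumption, not stated in the commit): the reduction kernel addresses the workspace as a row-major 2-D tensor, and a column-major m x n GEMM output with leading dimension ldc occupies the same memory as a row-major n x m matrix with stride ldc, which is why problem_size becomes {n, m} while ldw = lds = ldd = ldc.

#include "cutlass/gemm_coord.h"
#include "cutlass/matrix_coord.h"

// {rows, cols} = {n, m}: the commit message's "swap m and n for the
// reduction kernel".
cutlass::MatrixCoord reduction_extent(int m, int n, int k) {
  return cutlass::gemm::GemmCoord(n, m, k).mn();
}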

View File

@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
@ -45,6 +45,7 @@
#include "operation_profiler.h"
#include "performance_result.h"
#include "problem_space.h"
#include "reduction_operation_profiler.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
@ -61,6 +62,7 @@ public:
struct GemmProblem {
cutlass::library::GemmUniversalMode mode;
cutlass::library::SplitKMode split_k_mode;
int64_t m;
int64_t n;
int64_t k;
@ -72,6 +74,12 @@ public:
int split_k_slices;
int batch_count;
// gemm with parallel interleaved reduction
// gemm epilogue (alpha, beta) = (1.0, 0.0)
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
std::vector<uint8_t> alpha_one;
std::vector<uint8_t> beta_zero;
//
// Methods
//
@ -121,6 +129,13 @@ public:
/// Buffer used for the operations' device workspace
DeviceAllocation device_workspace;
/// Library configuration and arguments for reduction operator
library::ReductionConfiguration reduction_configuration;
library::ReductionArguments reduction_arguments;
/// Buffer used for the cutlass reduction operations' host workspace
std::vector<uint8_t> reduction_host_workspace;
//
// Methods
//
@ -141,6 +156,8 @@ protected:
/// Device memory allocations
GemmWorkspace gemm_workspace_;
/// CUTLASS parallel reduction operation to follow this* gemm operation
library::Operation const *reduction_op_;
public:
//
@ -231,6 +248,10 @@ protected:
void *host_workspace,
void *device_workspace);
/// Initialize reduction problem dimensions and library::Operation
bool initialize_reduction_configuration_(
library::Operation const *operation,
ProblemSpace::Problem const &problem);
};
/////////////////////////////////////////////////////////////////////////////////////////////////