Support parallel split K mode for profiling (#277)
* Support parallel split K mode for profiling Signed-off-by: Peter Han <fujun.han@iluvatar.ai> * Parallel Split K support 1. find gemm kernel by preference key 2. switch m n for reduction kernel Signed-off-by: Peter Han <fujun.han@iluvatar.ai> * parallel splitk for fp16 gemm * add one missing file Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
c3353add63
commit
1e4703cbab
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -338,6 +338,9 @@ using HandlePtr = std::unique_ptr<Handle>;
|
|||||||
/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace
|
/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace
|
||||||
Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation);
|
Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation);
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace
|
||||||
|
Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation);
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
} // namespace library
|
} // namespace library
|
||||||
} // namespace cutlass
|
} // namespace cutlass
|
||||||
|
@ -590,7 +590,8 @@ public:
|
|||||||
void const *configuration) const = 0;
|
void const *configuration) const = 0;
|
||||||
|
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration) const = 0;
|
void const *configuration,
|
||||||
|
void const *arguments = nullptr) const = 0;
|
||||||
|
|
||||||
virtual Status initialize(
|
virtual Status initialize(
|
||||||
void const *configuration,
|
void const *configuration,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -272,7 +272,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -266,7 +266,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -242,7 +242,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
@ -443,7 +444,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
@ -569,7 +571,7 @@ protected:
|
|||||||
operator_args.ldb = (configuration->ldb);
|
operator_args.ldb = (configuration->ldb);
|
||||||
operator_args.ldc = (configuration->ldc);
|
operator_args.ldc = (configuration->ldc);
|
||||||
operator_args.ldd = (configuration->ldd);
|
operator_args.ldd = (configuration->ldd);
|
||||||
|
|
||||||
return Status::kSuccess;
|
return Status::kSuccess;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -649,7 +651,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
@ -661,6 +664,14 @@ public:
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
status = update_arguments_(
|
||||||
|
args,
|
||||||
|
static_cast<GemmUniversalArguments const *>(arguments_ptr));
|
||||||
|
|
||||||
|
if (status != Status::kSuccess) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
uint64_t size = Operator::get_workspace_size(args);
|
uint64_t size = Operator::get_workspace_size(args);
|
||||||
|
|
||||||
return size;
|
return size;
|
||||||
@ -855,7 +866,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
@ -1055,7 +1067,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -1098,6 +1098,59 @@ Operation const* find_conv_operation_for_parallel_reduction(Operation const *ope
|
|||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
/// Finds gemm operation instances with Gemm::ElementC = Reduction::ElementWorkspace
|
||||||
|
Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation) {
|
||||||
|
|
||||||
|
GemmDescription const &gemm_desc =
|
||||||
|
static_cast<GemmDescription const &>(operation->description());
|
||||||
|
|
||||||
|
// if the curren gemm operation accumulator and output data type match return operation
|
||||||
|
if(gemm_desc.tile_description.math_instruction.element_accumulator == gemm_desc.C.element) {
|
||||||
|
return operation;
|
||||||
|
}
|
||||||
|
|
||||||
|
// find gemm operation to match gemm output and reduction workspace data type
|
||||||
|
GemmFunctionalKey key(
|
||||||
|
library::Provider::kCUTLASS,
|
||||||
|
gemm_desc.gemm_kind,
|
||||||
|
gemm_desc.tile_description.math_instruction.element_accumulator,
|
||||||
|
gemm_desc.element_epilogue,
|
||||||
|
gemm_desc.A.element,
|
||||||
|
gemm_desc.A.layout,
|
||||||
|
gemm_desc.transform_A,
|
||||||
|
gemm_desc.B.element,
|
||||||
|
gemm_desc.B.layout,
|
||||||
|
gemm_desc.transform_B,
|
||||||
|
gemm_desc.tile_description.math_instruction.element_accumulator);
|
||||||
|
|
||||||
|
// gemm operation table
|
||||||
|
auto gemm_operations = Singleton::get().operation_table.gemm_operations;
|
||||||
|
|
||||||
|
// find ConvFunctionalKey in gemm operation table
|
||||||
|
auto operators_it = gemm_operations.find(key);
|
||||||
|
|
||||||
|
if (operators_it == gemm_operations.end()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (operators_it->second.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// A and B uses the same alignment in the generator.py
|
||||||
|
int alignment = gemm_desc.A.alignment;
|
||||||
|
|
||||||
|
// gemm operation for same compute capability and iterator algorithm
|
||||||
|
GemmPreferenceKey preference_key(
|
||||||
|
gemm_desc.tile_description.minimum_compute_capability,
|
||||||
|
alignment);
|
||||||
|
|
||||||
|
return find_gemm_operation(operators_it, preference_key);
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
} // namespace library
|
} // namespace library
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -46,7 +46,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) {
|
|||||||
|
|
||||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||||
ElementOutput,
|
ElementOutput,
|
||||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
ElementCompute
|
ElementCompute
|
||||||
>;
|
>;
|
||||||
@ -81,7 +81,7 @@ void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) {
|
|||||||
|
|
||||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||||
ElementOutput,
|
ElementOutput,
|
||||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
ElementCompute
|
ElementCompute
|
||||||
>;
|
>;
|
||||||
@ -115,7 +115,7 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
|||||||
|
|
||||||
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<
|
||||||
ElementOutput,
|
ElementOutput,
|
||||||
128 / cutlass::sizeof_bits<ElementOutput>::value,
|
128 / cutlass::sizeof_bits<ElementWorkspace>::value,
|
||||||
ElementAccumulator,
|
ElementAccumulator,
|
||||||
ElementCompute
|
ElementCompute
|
||||||
>;
|
>;
|
||||||
@ -140,6 +140,5 @@ void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest)
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -30,6 +30,7 @@
|
|||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include "cutlass/cutlass.h"
|
#include "cutlass/cutlass.h"
|
||||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||||
|
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
|
||||||
#include "cutlass/reduction/thread/reduction_operators.h"
|
#include "cutlass/reduction/thread/reduction_operators.h"
|
||||||
#include "cutlass/reduction/device/reduce_split_k.h"
|
#include "cutlass/reduction/device/reduce_split_k.h"
|
||||||
|
|
||||||
@ -180,7 +181,8 @@ public:
|
|||||||
|
|
||||||
/// Gets the device-side workspace
|
/// Gets the device-side workspace
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration_ptr) const {
|
void const *configuration_ptr,
|
||||||
|
void const *arguments_ptr = nullptr) const {
|
||||||
|
|
||||||
OperatorArguments args;
|
OperatorArguments args;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -403,7 +403,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration) const {
|
void const *configuration,
|
||||||
|
void const *arguments = nullptr) const {
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -161,7 +161,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
virtual uint64_t get_device_workspace_size(
|
virtual uint64_t get_device_workspace_size(
|
||||||
void const *configuration) const {
|
void const *configuration,
|
||||||
|
void const *arguments = nullptr) const {
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -37,6 +37,7 @@
|
|||||||
#include "gemm_operation_profiler.h"
|
#include "gemm_operation_profiler.h"
|
||||||
#include "gpu_timer.h"
|
#include "gpu_timer.h"
|
||||||
|
|
||||||
|
#include "cutlass/library/singleton.h"
|
||||||
#include "cutlass/library/library.h"
|
#include "cutlass/library/library.h"
|
||||||
#include "cutlass/library/handle.h"
|
#include "cutlass/library/handle.h"
|
||||||
|
|
||||||
@ -55,6 +56,7 @@ GemmOperationProfiler::GemmOperationProfiler(Options const &options):
|
|||||||
library::OperationKind::kGemm,
|
library::OperationKind::kGemm,
|
||||||
{
|
{
|
||||||
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
|
{ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (gemm, batched, array, universal, planar_complex, planar_complex_array)"},
|
||||||
|
{ArgumentTypeID::kEnumerated, {"split_k_mode"}, "Variant of split K mode(serial, parallel)"},
|
||||||
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
|
{ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"},
|
||||||
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
|
{ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"},
|
||||||
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
|
{ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"},
|
||||||
@ -100,6 +102,9 @@ void GemmOperationProfiler::print_examples(std::ostream &out) const {
|
|||||||
<< "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
|
<< "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n"
|
||||||
<< " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
|
<< " $ cutlass_profiler --operation=Gemm --A=f16:column --B=*:row\n\n"
|
||||||
|
|
||||||
|
<< "Profile a particular problem size with split K and paralell reduction:\n"
|
||||||
|
<< " $ cutlass_profiler --operation=Gemm --split_k_mode=parallel --split_k_slices=2 --m=1024 --n=1024 --k=128\n\n"
|
||||||
|
|
||||||
<< "Using various input value distribution:\n"
|
<< "Using various input value distribution:\n"
|
||||||
<< " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
|
<< " $ cutlass_profiler --operation=Gemm --dist=uniform,min:0,max:3\n"
|
||||||
<< " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n"
|
<< " $ cutlass_profiler --operation=Gemm --dist=gaussian,mean:0,stddev:3\n"
|
||||||
@ -155,8 +160,17 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
|||||||
// default value
|
// default value
|
||||||
this->k = 1024;
|
this->k = 1024;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!arg_as_SplitKModeID(this->split_k_mode, "split_k_mode", problem_space, problem)) {
|
||||||
|
// default value
|
||||||
|
this->split_k_mode = library::SplitKMode::kSerial;
|
||||||
|
}
|
||||||
|
|
||||||
this->mode = library::GemmUniversalMode::kGemm;
|
this->mode = library::GemmUniversalMode::kGemm;
|
||||||
|
if(this->split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
this->mode = library::GemmUniversalMode::kGemmSplitKParallel;
|
||||||
|
}
|
||||||
|
|
||||||
if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) {
|
if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) {
|
||||||
// default value
|
// default value
|
||||||
this->split_k_slices = 1;
|
this->split_k_slices = 1;
|
||||||
@ -165,8 +179,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
|||||||
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
|
if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) {
|
||||||
// default value
|
// default value
|
||||||
this->batch_count = 1;
|
this->batch_count = 1;
|
||||||
}
|
} else if (this->batch_count > 1) {
|
||||||
else if (this->batch_count > 1) {
|
|
||||||
this->mode = library::GemmUniversalMode::kBatched;
|
this->mode = library::GemmUniversalMode::kBatched;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,7 +223,7 @@ Status GemmOperationProfiler::GemmProblem::parse(
|
|||||||
return Status::kErrorInternal;
|
return Status::kErrorInternal;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this->lda = DeviceAllocation::get_packed_layout(
|
this->lda = DeviceAllocation::get_packed_layout(
|
||||||
operation_desc.A.layout, {int(this->m), int(this->k)}).front();
|
operation_desc.A.layout, {int(this->m), int(this->k)}).front();
|
||||||
|
|
||||||
@ -279,6 +292,8 @@ void GemmOperationProfiler::GemmProblem::initialize_result(
|
|||||||
|
|
||||||
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
|
set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind));
|
||||||
|
|
||||||
|
set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode));
|
||||||
|
|
||||||
set_argument(result, "A", problem_space,
|
set_argument(result, "A", problem_space,
|
||||||
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
|
std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout));
|
||||||
|
|
||||||
@ -321,7 +336,7 @@ Status GemmOperationProfiler::initialize_configuration(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status status = problem_.parse(operation_desc, problem_space, problem);
|
Status status = problem_.parse(operation_desc, problem_space, problem);
|
||||||
|
|
||||||
if (status != Status::kSuccess) {
|
if (status != Status::kSuccess) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
@ -350,6 +365,13 @@ Status GemmOperationProfiler::initialize_configuration(
|
|||||||
gemm_workspace_.arguments.beta = problem_.beta.data();
|
gemm_workspace_.arguments.beta = problem_.beta.data();
|
||||||
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||||
|
|
||||||
|
// initialize reduction operation for parallel splitKMode
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
if (!initialize_reduction_configuration_(operation, problem)) {
|
||||||
|
return Status::kErrorInternal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
initialize_result_(this->model_result_, options, operation_desc, problem_space);
|
||||||
|
|
||||||
return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
|
return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments);
|
||||||
@ -366,7 +388,7 @@ void GemmOperationProfiler::initialize_result_(
|
|||||||
result.disposition = Disposition::kNotRun;
|
result.disposition = Disposition::kNotRun;
|
||||||
result.status = Status::kSuccess;
|
result.status = Status::kSuccess;
|
||||||
result.operation_name = operation_desc.name;
|
result.operation_name = operation_desc.name;
|
||||||
|
|
||||||
problem_.initialize_result(result, operation_desc, problem_space);
|
problem_.initialize_result(result, operation_desc, problem_space);
|
||||||
|
|
||||||
OperationProfiler::initialize_result_(result, operation_desc, problem_space);
|
OperationProfiler::initialize_result_(result, operation_desc, problem_space);
|
||||||
@ -377,6 +399,51 @@ void GemmOperationProfiler::initialize_result_(
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initialize redution problem dimentions and library::Operation
|
||||||
|
bool GemmOperationProfiler::initialize_reduction_configuration_(
|
||||||
|
library::Operation const *operation,
|
||||||
|
ProblemSpace::Problem const &problem) {
|
||||||
|
library::GemmDescription const &gemm_desc =
|
||||||
|
static_cast<library::GemmDescription const&>(operation->description());
|
||||||
|
|
||||||
|
if (!cast_from_double(problem_.alpha_one, gemm_desc.element_epilogue, 1)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!cast_from_double(problem_.beta_zero, gemm_desc.element_epilogue, 0)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// initialize library::ReductionConfiguration
|
||||||
|
gemm_workspace_.reduction_configuration.problem_size = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn();
|
||||||
|
gemm_workspace_.reduction_configuration.partitions = int(problem_.split_k_slices);
|
||||||
|
gemm_workspace_.reduction_configuration.partition_stride = gemm::GemmCoord(int(problem_.n), int(problem_.m), int(problem_.k)).mn().product();
|
||||||
|
gemm_workspace_.reduction_configuration.ldw = problem_.ldc;
|
||||||
|
gemm_workspace_.reduction_configuration.lds = problem_.ldc;
|
||||||
|
gemm_workspace_.reduction_configuration.ldd = problem_.ldc;
|
||||||
|
|
||||||
|
// find reduction operation
|
||||||
|
library::ReductionFunctionalKey reduction_key(
|
||||||
|
library::Provider::kCUTLASS,
|
||||||
|
gemm_desc.tile_description.math_instruction.element_accumulator, // element workspace
|
||||||
|
gemm_desc.tile_description.math_instruction.element_accumulator, // element accumulator
|
||||||
|
gemm_desc.C.element, // element output
|
||||||
|
gemm_desc.element_epilogue // element coumpute
|
||||||
|
);
|
||||||
|
|
||||||
|
auto reduction_it = library::Singleton::get().operation_table.reduction_operations.find(reduction_key);
|
||||||
|
|
||||||
|
if (reduction_it == library::Singleton::get().operation_table.reduction_operations.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// initialize reduction operation required for parallel split-k operator
|
||||||
|
reduction_op_ = reduction_it->second;
|
||||||
|
|
||||||
|
// reduction operation found and initialized
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
/// Initializes workspace
|
/// Initializes workspace
|
||||||
Status GemmOperationProfiler::initialize_workspace(
|
Status GemmOperationProfiler::initialize_workspace(
|
||||||
Options const &options,
|
Options const &options,
|
||||||
@ -385,7 +452,15 @@ Status GemmOperationProfiler::initialize_workspace(
|
|||||||
library::Operation const *operation,
|
library::Operation const *operation,
|
||||||
ProblemSpace const &problem_space,
|
ProblemSpace const &problem_space,
|
||||||
ProblemSpace::Problem const &problem) {
|
ProblemSpace::Problem const &problem) {
|
||||||
|
|
||||||
|
library::Operation const* underlying_operation = operation;
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
|
||||||
|
return Status::kErrorNotSupported;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
library::GemmDescription const &operation_desc =
|
library::GemmDescription const &operation_desc =
|
||||||
static_cast<library::GemmDescription const &>(operation->description());
|
static_cast<library::GemmDescription const &>(operation->description());
|
||||||
|
|
||||||
@ -455,8 +530,12 @@ Status GemmOperationProfiler::initialize_workspace(
|
|||||||
);
|
);
|
||||||
|
|
||||||
gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data());
|
gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data());
|
||||||
}
|
|
||||||
|
|
||||||
|
gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride();
|
||||||
|
gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride();
|
||||||
|
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
||||||
|
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Initialize the CUTLASS operation
|
// Initialize the CUTLASS operation
|
||||||
@ -467,16 +546,35 @@ Status GemmOperationProfiler::initialize_workspace(
|
|||||||
|
|
||||||
if (options.execution_mode != ExecutionMode::kDryRun) {
|
if (options.execution_mode != ExecutionMode::kDryRun) {
|
||||||
|
|
||||||
uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration);
|
uint64_t workspace_size = underlying_operation->get_host_workspace_size(&gemm_workspace_.configuration);
|
||||||
gemm_workspace_.host_workspace.resize(workspace_size, 0);
|
gemm_workspace_.host_workspace.resize(workspace_size, 0);
|
||||||
|
|
||||||
workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration);
|
workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration,
|
||||||
|
&gemm_workspace_.arguments);
|
||||||
gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
|
gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size);
|
||||||
|
|
||||||
status = operation->initialize(
|
status = underlying_operation->initialize(
|
||||||
&gemm_workspace_.configuration,
|
&gemm_workspace_.configuration,
|
||||||
gemm_workspace_.host_workspace.data(),
|
gemm_workspace_.host_workspace.data(),
|
||||||
gemm_workspace_.device_workspace.data());
|
gemm_workspace_.device_workspace.data());
|
||||||
|
|
||||||
|
if (status != Status::kSuccess) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
workspace_size = reduction_op_->get_host_workspace_size(&gemm_workspace_.reduction_configuration);
|
||||||
|
gemm_workspace_.reduction_host_workspace.resize(workspace_size, 0);
|
||||||
|
|
||||||
|
status = reduction_op_->initialize(
|
||||||
|
&gemm_workspace_.reduction_configuration,
|
||||||
|
gemm_workspace_.reduction_host_workspace.data(),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
if (status != Status::kSuccess) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -527,11 +625,34 @@ bool GemmOperationProfiler::verify_cutlass(
|
|||||||
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
||||||
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||||
|
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
|
||||||
|
gemm_workspace_.arguments.beta = problem_.beta_zero.data();
|
||||||
|
|
||||||
|
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||||
|
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
|
||||||
|
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
|
||||||
|
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
|
||||||
|
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
|
||||||
|
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Run the CUTLASS operation
|
// Run the CUTLASS operation
|
||||||
//
|
//
|
||||||
|
|
||||||
results_.back().status = operation->run(
|
// initialize gemm underlying operation to handle parallel reduction
|
||||||
|
library::Operation const * underlying_operation = operation;
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
|
||||||
|
results_.back().disposition = Disposition::kFailed;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results_.back().status = underlying_operation->run(
|
||||||
&gemm_workspace_.arguments,
|
&gemm_workspace_.arguments,
|
||||||
gemm_workspace_.host_workspace.data(),
|
gemm_workspace_.host_workspace.data(),
|
||||||
gemm_workspace_.device_workspace.data());
|
gemm_workspace_.device_workspace.data());
|
||||||
@ -541,6 +662,19 @@ bool GemmOperationProfiler::verify_cutlass(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run parallel reduction kernel for parallel split_k_mode
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
results_.back().status = reduction_op_->run(
|
||||||
|
&gemm_workspace_.reduction_arguments,
|
||||||
|
gemm_workspace_.reduction_host_workspace.data(),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
if (results_.back().status != Status::kSuccess) {
|
||||||
|
results_.back().disposition = Disposition::kFailed;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
cudaError_t result = cudaDeviceSynchronize();
|
cudaError_t result = cudaDeviceSynchronize();
|
||||||
if (result != cudaSuccess) {
|
if (result != cudaSuccess) {
|
||||||
results_.back().disposition = Disposition::kFailed;
|
results_.back().disposition = Disposition::kFailed;
|
||||||
@ -896,6 +1030,19 @@ bool GemmOperationProfiler::profile(
|
|||||||
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride();
|
||||||
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride();
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||||
|
gemm_workspace_.arguments.alpha = problem_.alpha_one.data();
|
||||||
|
gemm_workspace_.arguments.beta = problem_.beta_zero.data();
|
||||||
|
|
||||||
|
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||||
|
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->data();
|
||||||
|
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->data();
|
||||||
|
gemm_workspace_.reduction_arguments.alpha = problem_.alpha.data();
|
||||||
|
gemm_workspace_.reduction_arguments.beta = problem_.beta.data();
|
||||||
|
gemm_workspace_.reduction_arguments.pointer_mode = library::ScalarPointerMode::kHost;
|
||||||
|
}
|
||||||
|
|
||||||
results_.back().status = profile_cutlass_(
|
results_.back().status = profile_cutlass_(
|
||||||
results_.back().runtime,
|
results_.back().runtime,
|
||||||
options,
|
options,
|
||||||
@ -921,6 +1068,15 @@ Status GemmOperationProfiler::profile_cutlass_(
|
|||||||
|
|
||||||
GpuTimer timer;
|
GpuTimer timer;
|
||||||
|
|
||||||
|
// initialize gemm underlying operation to handle parallel reduction
|
||||||
|
library::Operation const * underlying_operation = operation;
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
if (!(underlying_operation = library::find_gemm_operation_for_parallel_reduction(operation))) {
|
||||||
|
return Status::kErrorNotSupported;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// Optional sleep to limit power consumption and thermals
|
// Optional sleep to limit power consumption and thermals
|
||||||
//
|
//
|
||||||
@ -942,8 +1098,16 @@ Status GemmOperationProfiler::profile_cutlass_(
|
|||||||
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
|
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
|
||||||
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
|
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||||
|
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||||
|
|
||||||
|
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||||
|
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
|
||||||
|
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||||
|
}
|
||||||
|
|
||||||
// Execute the CUTLASS operation
|
// Execute the CUTLASS operation
|
||||||
status = operation->run(
|
status = underlying_operation->run(
|
||||||
&gemm_workspace_.arguments,
|
&gemm_workspace_.arguments,
|
||||||
host_workspace,
|
host_workspace,
|
||||||
device_workspace);
|
device_workspace);
|
||||||
@ -951,6 +1115,18 @@ Status GemmOperationProfiler::profile_cutlass_(
|
|||||||
if (status != Status::kSuccess) {
|
if (status != Status::kSuccess) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run parallel reduction kernel for parallel split_k_mode
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
status = reduction_op_->run(
|
||||||
|
&gemm_workspace_.reduction_arguments,
|
||||||
|
gemm_workspace_.reduction_host_workspace.data(),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
if (status != Status::kSuccess) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -977,7 +1153,15 @@ Status GemmOperationProfiler::profile_cutlass_(
|
|||||||
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
|
gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx);
|
||||||
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
|
gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||||
|
|
||||||
status = operation->run(
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
gemm_workspace_.arguments.D = gemm_workspace_.device_workspace.data();
|
||||||
|
|
||||||
|
gemm_workspace_.reduction_arguments.workspace = gemm_workspace_.device_workspace.data();
|
||||||
|
gemm_workspace_.reduction_arguments.source = gemm_workspace_.C->batch_data(problem_idx);
|
||||||
|
gemm_workspace_.reduction_arguments.destination = gemm_workspace_.Computed->batch_data(problem_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
status = underlying_operation->run(
|
||||||
arguments,
|
arguments,
|
||||||
host_workspace,
|
host_workspace,
|
||||||
device_workspace);
|
device_workspace);
|
||||||
@ -985,6 +1169,18 @@ Status GemmOperationProfiler::profile_cutlass_(
|
|||||||
if (status != Status::kSuccess) {
|
if (status != Status::kSuccess) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Run parallel reduction kernel for parallel split_k_mode
|
||||||
|
if (problem_.split_k_mode == library::SplitKMode::kParallel) {
|
||||||
|
status = reduction_op_->run(
|
||||||
|
&gemm_workspace_.reduction_arguments,
|
||||||
|
gemm_workspace_.reduction_host_workspace.data(),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
if (status != Status::kSuccess) {
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/***************************************************************************************************
|
/***************************************************************************************************
|
||||||
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
|
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
|
||||||
*
|
*
|
||||||
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||||||
* provided that the following conditions are met:
|
* provided that the following conditions are met:
|
||||||
@ -45,6 +45,7 @@
|
|||||||
#include "operation_profiler.h"
|
#include "operation_profiler.h"
|
||||||
#include "performance_result.h"
|
#include "performance_result.h"
|
||||||
#include "problem_space.h"
|
#include "problem_space.h"
|
||||||
|
#include "reduction_operation_profiler.h"
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
@ -61,6 +62,7 @@ public:
|
|||||||
struct GemmProblem {
|
struct GemmProblem {
|
||||||
|
|
||||||
cutlass::library::GemmUniversalMode mode;
|
cutlass::library::GemmUniversalMode mode;
|
||||||
|
cutlass::library::SplitKMode split_k_mode;
|
||||||
int64_t m;
|
int64_t m;
|
||||||
int64_t n;
|
int64_t n;
|
||||||
int64_t k;
|
int64_t k;
|
||||||
@ -72,6 +74,12 @@ public:
|
|||||||
int split_k_slices;
|
int split_k_slices;
|
||||||
int batch_count;
|
int batch_count;
|
||||||
|
|
||||||
|
// gemm with parallel interleaved reduction
|
||||||
|
// gemm epilogue (alpha, beta) = (1.0, 0.0)
|
||||||
|
// reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta)
|
||||||
|
std::vector<uint8_t> alpha_one;
|
||||||
|
std::vector<uint8_t> beta_zero;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Methods
|
// Methods
|
||||||
//
|
//
|
||||||
@ -121,6 +129,13 @@ public:
|
|||||||
/// Buffer used for the operations' device workspace
|
/// Buffer used for the operations' device workspace
|
||||||
DeviceAllocation device_workspace;
|
DeviceAllocation device_workspace;
|
||||||
|
|
||||||
|
/// Library configuration and arguments for reduction operator
|
||||||
|
library::ReductionConfiguration reduction_configuration;
|
||||||
|
library::ReductionArguments reduction_arguments;
|
||||||
|
|
||||||
|
/// Buffer used for the cutlass reduction operations' host workspace
|
||||||
|
std::vector<uint8_t> reduction_host_workspace;
|
||||||
|
|
||||||
//
|
//
|
||||||
// Methods
|
// Methods
|
||||||
//
|
//
|
||||||
@ -141,6 +156,8 @@ protected:
|
|||||||
/// Device memory allocations
|
/// Device memory allocations
|
||||||
GemmWorkspace gemm_workspace_;
|
GemmWorkspace gemm_workspace_;
|
||||||
|
|
||||||
|
/// CUTLASS parallel reduction operation to follow this* gemm operation
|
||||||
|
library::Operation const *reduction_op_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
//
|
//
|
||||||
@ -231,6 +248,10 @@ protected:
|
|||||||
void *host_workspace,
|
void *host_workspace,
|
||||||
void *device_workspace);
|
void *device_workspace);
|
||||||
|
|
||||||
|
/// Initialize reduction problem dimensions and library::Operation
|
||||||
|
bool initialize_reduction_configuration_(
|
||||||
|
library::Operation const *operation,
|
||||||
|
ProblemSpace::Problem const &problem);
|
||||||
};
|
};
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
Loading…
Reference in New Issue
Block a user