/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Helper functions for mapping CUTLASS concepts to cuDNN.
*/
#pragma once
#if CUTLASS_ENABLE_CUDNN
#include <cuda_runtime.h>
#include <cudnn.h>
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/library/library.h"
#include "enumerated_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace profiler {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Converts a cuDNN status to cutlass::Status
Status get_cutlass_status(cudnnStatus_t cudnn_status);
/// Converts a cuDNN status to cutlass::profiler::Disposition
Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status);
/// Checks a cudnnStatus_t, converts it to a cutlass::Status, and returns it if it is Status::kSuccess; otherwise throws an exception
Status checkCudnnErr(cudnnStatus_t cudnn_status);
/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration
bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode);
/// Maps a CUTLASS layout type to a cuDNN data type enumeration
bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout);
/// Maps a CUTLASS numeric type to a cuDNN data type enumeration
bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type);
/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type
bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc);
/// Returns a status if cudnn can satisfy a particular Conv2d description
Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration);
/// Returns a status if cudnn can satisfy a particular Conv3d description
Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration);
/// The cuDNN compute type appears to be hardcoded to float, so the epilogue scalar is cast to float (to work around a possible cuDNN issue)
float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src);
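//
// Example usage of the conversion helpers above (illustrative sketch only; the
// enumerators shown are standard cutlass::library / cuDNN values, but the control
// flow around them is hypothetical and not part of this header):
//
//   cudnnDataType_t cudnn_type;
//   cudnnTensorFormat_t cudnn_layout;
//   if (get_cudnn_datatype(cudnn_type, library::NumericTypeID::kF32) &&
//       get_cudnn_layout(cudnn_layout, library::LayoutTypeID::kTensorNHWC)) {
//     // Expected to yield CUDNN_DATA_FLOAT and CUDNN_TENSOR_NHWC, respectively.
//   }
//   else {
//     // No cuDNN equivalent exists; the profiler skips cuDNN as a reference for this kernel.
//   }
//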
/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and
/// to destroy cudnnHandle_t on CudnnCreate object destruction.
/// Additionally, it provides an implicit conversion from a CudnnCreate object to cudnnHandle_t
class CudnnCreate {
private:
cudnnHandle_t handle;
cudnnStatus_t status;
public:
CudnnCreate() {
status = cudnnCreate(&handle);
}
~CudnnCreate() {
cudnnDestroy(handle);
}
/// Implicit cast CudnnCreate object to cudnnHandle_t
operator cudnnHandle_t() const { return handle; }
/// returns cudnnStatus_t for handle creation
cudnnStatus_t get_cudnn_create_status() { return status; }
};
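//
// Example usage of CudnnCreate (illustrative sketch; the error-handling branch is
// hypothetical and not part of this header):
//
//   CudnnCreate handle;
//   if (handle.get_cudnn_create_status() != CUDNN_STATUS_SUCCESS) {
//     // Handle creation failed; report the failure and skip cuDNN reference checks.
//   }
//   // The implicit conversion allows passing the object directly to cuDNN APIs:
//   cudnnStatus_t cudnn_status = cudnnSetStream(handle, nullptr);
//   // cudnnDestroy() is invoked automatically when `handle` goes out of scope.
//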
namespace detail {
/// Dispatcher to cudnn convolution operators
struct cudnnConvDispatcher {
//
// Data members
//
//library::Conv2dConfiguration configuration;
library::ConvArguments arguments;
library::ConvKind conv_kind;
// cudnn-specific data structures to fill cudnn API call arguments
// cudnn activation, filter, and output descriptors
cudnnTensorDescriptor_t activation_desc;
cudnnFilterDescriptor_t filter_desc;
cudnnTensorDescriptor_t output_desc;
cudnnConvolutionDescriptor_t conv_desc;
// cudnn datatypes
cudnnDataType_t data_type_activation;
cudnnDataType_t data_type_filter;
cudnnDataType_t data_type_output;
// cudnn layouts
cudnnTensorFormat_t layout_activation;
cudnnTensorFormat_t layout_filter;
cudnnTensorFormat_t layout_output;
// cudnn convolution mode
cudnnConvolutionMode_t conv_mode;
// cudnn math type (tensorop, tensorop with conversion, simt)
cudnnMathType_t math_type;
// cudnn compute data type
cudnnDataType_t compute_type;
// cudnn compute type appears to be hardcoded to float (to work around a possible cudnn issue)
float alpha;
float beta;
// cudnn workspace
size_t workspace_size_in_bytes = 0;
cutlass::device_memory::allocation<char> workspace;
// select cudnn's implicit gemm precomputed algorithm with tensor operations
static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
Status status;
//
// Methods
//
// TODO: unify the conv2d and conv3d cudnnConvDispatcher ctors by unifying Conv2dConfiguration
// ctor for conv2d
cudnnConvDispatcher(
library::ConvDescription const &op_desc,
library::Conv2dConfiguration configuration,
library::ConvArguments arguments_,
cudnnHandle_t handle
):
//configuration(configuration_),
arguments(arguments_),
conv_kind(op_desc.conv_kind),
status(Status::kSuccess) {
bool good = true;
// Get cudnn datatype, layout, and convolution mode from library::ConvDescription
good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element));
good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element));
good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element));
good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout));
good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout));
good = (good && get_cudnn_layout(layout_output, op_desc.C.layout));
good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode));
// Get cudnn mathtype (cudnnMathType_t)
good = (good && get_cudnn_mathtype(math_type, op_desc));
good = (good && get_cudnn_datatype(
compute_type,
op_desc.tile_description.math_instruction.element_accumulator));
// Check that the cutlass Conv2d description has an equivalent operator in cudnn
if (!good) {
status = Status::kErrorNotSupported;
return;
}
// cudnn compute type appears to be hardcoded to float (to work around a possible cudnn issue)
alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha);
beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta);
// Create convolution descriptor object
status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc));
// Configure convolution operator
std::vector<int> padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w};
std::vector<int> stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w};
std::vector<int> dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w};
status = get_cutlass_status(
cudnnSetConvolutionNdDescriptor(
conv_desc,
op_desc.conv_dim,
padding.data(),
stride.data(),
dilation.data(),
conv_mode,
compute_type
));
// Set groups
status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups));
// Create activation, filter, and output descriptor objects
status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc));
status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc));
status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc));
// Set activation, filter, and output descriptor
status = get_cutlass_status(
cudnnSetTensor4dDescriptor(
activation_desc,
layout_activation,
data_type_activation,
configuration.problem_size.N,
configuration.problem_size.C,
configuration.problem_size.H,
configuration.problem_size.W
));
status = get_cutlass_status(
cudnnSetFilter4dDescriptor(
filter_desc,
data_type_filter,
layout_filter,
configuration.problem_size.K,
configuration.problem_size.C / configuration.problem_size.groups,
configuration.problem_size.R,
configuration.problem_size.S
));
status = get_cutlass_status(
cudnnSetTensor4dDescriptor(
output_desc,
layout_output,
data_type_output,
configuration.problem_size.N,
configuration.problem_size.K,
configuration.problem_size.P,
configuration.problem_size.Q
));
// Set the cudnn math type determined above (tensor op, tensor op with conversion, or simt)
status = get_cutlass_status(
cudnnSetConvolutionMathType(conv_desc, math_type));
// Initialize workspace
switch (conv_kind) {
case library::ConvKind::kFprop:
status = get_cutlass_status(
cudnnGetConvolutionForwardWorkspaceSize(
handle,
activation_desc,
filter_desc,
conv_desc,
output_desc,
fprop_algo,
&workspace_size_in_bytes
)); break;
case library::ConvKind::kDgrad:
status = get_cutlass_status(
cudnnGetConvolutionBackwardDataWorkspaceSize(
handle,
filter_desc,
output_desc,
conv_desc,
activation_desc,
dgrad_algo,
&workspace_size_in_bytes
)); break;
case library::ConvKind::kWgrad:
status = get_cutlass_status(
cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle,
activation_desc,
output_desc,
conv_desc,
filter_desc,
wgrad_algo,
&workspace_size_in_bytes
)); break;
}
workspace = cutlass::device_memory::allocation<char>(workspace_size_in_bytes);
}
// ctor for conv3d
cudnnConvDispatcher(
library::ConvDescription const &op_desc,
library::Conv3dConfiguration configuration,
library::ConvArguments arguments_,
cudnnHandle_t handle
):
//configuration(configuration_),
arguments(arguments_),
conv_kind(op_desc.conv_kind),
status(Status::kSuccess) {
bool good = true;
// Get cudnn datatype, layout, and convolution mode from library::ConvDescription
good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element));
good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element));
good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element));
good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout));
good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout));
good = (good && get_cudnn_layout(layout_output, op_desc.C.layout));
good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode));
// cudnn compute type appears to be hardcoded to float (to work around a possible cudnn issue)
alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha);
beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta);
good = (good && get_cudnn_datatype(
compute_type,
op_desc.tile_description.math_instruction.element_accumulator));
// Get cudnn mathtype (cudnnMathType_t) so that math_type is initialized before
// cudnnSetConvolutionMathType() is called below (mirrors the conv2d ctor)
good = (good && get_cudnn_mathtype(math_type, op_desc));
// Check that the cutlass Conv3d description has an equivalent operator in cudnn
if (!good) {
status = Status::kErrorNotSupported;
return;
}
// Create convolution descriptor object
status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc));
// Configure convolution operator
std::vector<int> padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w};
std::vector<int> stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w};
std::vector<int> dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w};
status = get_cutlass_status(
cudnnSetConvolutionNdDescriptor(
conv_desc,
op_desc.conv_dim,
padding.data(),
stride.data(),
dilation.data(),
conv_mode,
compute_type
));
// Set groups
status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups));
// Create activation, filter, and output descriptor objects
status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc));
status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc));
status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc));
// Set activation descriptor
std::vector<int> activation_extent {
configuration.problem_size.N,
configuration.problem_size.C,
configuration.problem_size.D,
configuration.problem_size.H,
configuration.problem_size.W
};
std::vector<int> activation_stride {
configuration.layout_activations.stride()[3],
1,
configuration.layout_activations.stride()[2],
configuration.layout_activations.stride()[1],
configuration.layout_activations.stride()[0]
};
status = get_cutlass_status(
cudnnSetTensorNdDescriptor(
activation_desc,
data_type_activation,
op_desc.conv_dim + 2,
activation_extent.data(),
activation_stride.data()
));
// Set filter descriptor
std::vector<int> filter_extent {
configuration.problem_size.K,
configuration.problem_size.C,
configuration.problem_size.T,
configuration.problem_size.R,
configuration.problem_size.S
};
std::vector<int> filter_stride {
configuration.layout_filters.stride()[3],
1,
configuration.layout_filters.stride()[2],
configuration.layout_filters.stride()[1],
configuration.layout_filters.stride()[0]
};
status = get_cutlass_status(
cudnnSetFilterNdDescriptor(
filter_desc,
data_type_filter,
layout_filter,
op_desc.conv_dim + 2,
filter_extent.data()
));
// Set output descriptor
std::vector<int> output_extent {
configuration.problem_size.N,
configuration.problem_size.K,
configuration.problem_size.Z,
configuration.problem_size.P,
configuration.problem_size.Q
};
std::vector<int> output_stride {
configuration.layout_output.stride()[3],
1,
configuration.layout_output.stride()[2],
configuration.layout_output.stride()[1],
configuration.layout_output.stride()[0]
};
status = get_cutlass_status(
cudnnSetTensorNdDescriptor(
output_desc,
data_type_output,
op_desc.conv_dim + 2,
output_extent.data(),
output_stride.data()
));
// Set the cudnn math type determined above (tensor op, tensor op with conversion, or simt)
status = get_cutlass_status(
cudnnSetConvolutionMathType(conv_desc, math_type));
// Initialize workspace
switch (conv_kind) {
case library::ConvKind::kFprop:
status = get_cutlass_status(
cudnnGetConvolutionForwardWorkspaceSize(
handle,
activation_desc,
filter_desc,
conv_desc,
output_desc,
fprop_algo,
&workspace_size_in_bytes
)); break;
case library::ConvKind::kDgrad:
status = get_cutlass_status(
cudnnGetConvolutionBackwardDataWorkspaceSize(
handle,
filter_desc,
output_desc,
conv_desc,
activation_desc,
dgrad_algo,
&workspace_size_in_bytes
)); break;
case library::ConvKind::kWgrad:
status = get_cutlass_status(
cudnnGetConvolutionBackwardFilterWorkspaceSize(
handle,
activation_desc,
output_desc,
conv_desc,
filter_desc,
wgrad_algo,
&workspace_size_in_bytes
)); break;
}
workspace = cutlass::device_memory::allocation<char>(workspace_size_in_bytes);
}
/// Executes the convolution operator (fprop, dgrad, or wgrad) using the cudnn library
cudnnStatus_t operator()(cudnnHandle_t handle) {
switch (conv_kind) {
case library::ConvKind::kFprop:
return cudnnConvolutionForward(
handle,
&alpha,
activation_desc,
activation(),
filter_desc,
filter(),
conv_desc,
fprop_algo,
workspace.get(),
workspace_size_in_bytes,
&beta,
output_desc,
arguments.D
);
case library::ConvKind::kDgrad:
return cudnnConvolutionBackwardData(
handle,
&alpha,
filter_desc,
filter(),
output_desc,
output(),
conv_desc,
dgrad_algo,
workspace.get(),
workspace_size_in_bytes,
&beta,
activation_desc,
arguments.D
);
case library::ConvKind::kWgrad:
return cudnnConvolutionBackwardFilter(
handle,
&alpha,
activation_desc,
activation(),
output_desc,
output(),
conv_desc,
wgrad_algo,
workspace.get(),
workspace_size_in_bytes,
&beta,
filter_desc,
arguments.D
);
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
}
}
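//
// Mapping of CUTLASS operand pointers (A, B, C) to cuDNN tensors, as implemented by
// the accessors below (arguments.D always receives the tensor computed by cuDNN):
//
//            activation   filter   output
//   kFprop        A          B        C
//   kDgrad        C          B        A
//   kWgrad        B          C        A
//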
// Returns Activation Tensor
void const * activation() const {
switch(conv_kind) {
case library::ConvKind::kFprop : return arguments.A;
case library::ConvKind::kDgrad : return arguments.C;
case library::ConvKind::kWgrad : return arguments.B;
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
}
}
// Returns Filter Tensor
void const *filter() const {
switch(conv_kind) {
case library::ConvKind::kFprop : return arguments.B;
case library::ConvKind::kDgrad : return arguments.B;
case library::ConvKind::kWgrad : return arguments.C;
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
}
}
// Returns Output Tensor
void const *output() const {
switch(conv_kind) {
case library::ConvKind::kFprop : return arguments.C;
case library::ConvKind::kDgrad : return arguments.A;
case library::ConvKind::kWgrad : return arguments.A;
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
}
}
};
} // namespace detail
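//
// Example of how the profiler might drive detail::cudnnConvDispatcher (illustrative
// sketch; op_desc, configuration, and arguments are assumed to be populated elsewhere
// and are not defined in this header):
//
//   CudnnCreate handle;
//   detail::cudnnConvDispatcher dispatcher(op_desc, configuration, arguments, handle);
//   if (dispatcher.status == Status::kSuccess) {
//     cudnnStatus_t cudnn_status = dispatcher(handle);
//     Disposition disposition = get_cutlass_disposition(cudnn_status);
//   }
//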
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace profiler
} // namespace cutlass

#endif //#if CUTLASS_ENABLE_CUDNN