591 lines
20 KiB
C++
591 lines
20 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
/* \file
|
|
\brief Helper functions for mapping CUTLASS concepts to cuDNN.
|
|
|
|
*/
|
|
|
|
#pragma once
|
|
#if CUTLASS_ENABLE_CUDNN
|
|
#include <cuda_runtime.h>
|
|
#include <cudnn.h>
|
|
#include <iostream>
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/util/device_memory.h"
|
|
#include "cutlass/library/library.h"
|
|
#include "enumerated_types.h"
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass {
|
|
namespace profiler {
|
|
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
/// Converts a cuDNN status to cutlass::Status
|
|
Status get_cutlass_status(cudnnStatus_t cudnn_status);
|
|
|
|
/// Converts a cuDNN status to cutlass::profiler::Disposition
|
|
Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status);
|
|
|
|
/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. throws exception
|
|
Status checkCudnnErr(cudnnStatus_t cudnn_status);
|
|
|
|
/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration
|
|
bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode);
|
|
|
|
/// Maps a CUTLASS layout type to a cuDNN data type enumeration
|
|
bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout);
|
|
|
|
/// Maps a CUTLASS numeric type to a cuDNN data type enumeration
|
|
bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type);
|
|
|
|
/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type
|
|
bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc);
|
|
|
|
/// Returns a status if cudnn can satisfy a particular Conv2d description
|
|
Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration);
|
|
|
|
/// Returns a status if cudnn can satisfy a particular Conv3d description
|
|
Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration);
|
|
|
|
/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue)
|
|
float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src);
|
|
|
|
|
|
/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and
|
|
/// to destroy cudnnHandle_t on CudnnCreate object destruction.
|
|
/// Additionaly, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object
|
|
class CudnnCreate {
|
|
private:
|
|
cudnnHandle_t handle;
|
|
cudnnStatus_t status;
|
|
|
|
public:
|
|
CudnnCreate() {
|
|
status = cudnnCreate(&handle);
|
|
}
|
|
|
|
~CudnnCreate() {
|
|
cudnnDestroy(handle);
|
|
}
|
|
|
|
/// Implicit cast CudnnCreate object to cudnnHandle_t
|
|
operator cudnnHandle_t() const { return handle; }
|
|
|
|
/// returns cudnnStatus_t for handle creation
|
|
cudnnStatus_t get_cudnn_create_status() { return status; }
|
|
};
|
|
|
|
|
|
namespace detail {
|
|
|
|
/// Dispatcher to cudnn convolution operators
|
|
struct cudnnConvDispatcher {
|
|
|
|
//
|
|
// Data members
|
|
//
|
|
//library::Conv2dConfiguration configuration;
|
|
library::ConvArguments arguments;
|
|
library::ConvKind conv_kind;
|
|
|
|
// cudnn-specific data structures to fill cudnn API call arguments
|
|
// cudnn activation, filter, and output descriptors
|
|
cudnnTensorDescriptor_t activation_desc;
|
|
cudnnFilterDescriptor_t filter_desc;
|
|
cudnnTensorDescriptor_t output_desc;
|
|
cudnnConvolutionDescriptor_t conv_desc;
|
|
|
|
// cudnn datatypes
|
|
cudnnDataType_t data_type_activation;
|
|
cudnnDataType_t data_type_filter;
|
|
cudnnDataType_t data_type_output;
|
|
|
|
// cudnn layouts
|
|
cudnnTensorFormat_t layout_activation;
|
|
cudnnTensorFormat_t layout_filter;
|
|
cudnnTensorFormat_t layout_output;
|
|
|
|
// cudnn convolution mode
|
|
cudnnConvolutionMode_t conv_mode;
|
|
|
|
// cudnn math type (tensorop, tensorop with conversion, simt)
|
|
cudnnMathType_t math_type;
|
|
|
|
// cudnn compute data type
|
|
cudnnDataType_t compute_type;
|
|
|
|
// cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue)
|
|
float alpha;
|
|
float beta;
|
|
|
|
// cudnn workspace
|
|
size_t workspace_size_in_bytes = 0;
|
|
cutlass::device_memory::allocation<char> workspace;
|
|
|
|
// select cudnn's implicit gemm precomputed algorithm with tensor operations
|
|
static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
|
|
static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
|
|
static cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
|
|
|
|
Status status;
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
// TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfigration
|
|
|
|
// ctor for conv2d
|
|
cudnnConvDispatcher(
|
|
library::ConvDescription const &op_desc,
|
|
library::Conv2dConfiguration configuration,
|
|
library::ConvArguments arguments_,
|
|
cudnnHandle_t handle
|
|
):
|
|
//configuration(configuration_),
|
|
arguments(arguments_),
|
|
conv_kind(op_desc.conv_kind),
|
|
status(Status::kSuccess) {
|
|
|
|
bool good = true;
|
|
|
|
// Get cudnn datatype, layout, and convolution mode from library::ConvDescription
|
|
good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element));
|
|
good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element));
|
|
good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element));
|
|
good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout));
|
|
good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout));
|
|
good = (good && get_cudnn_layout(layout_output, op_desc.C.layout));
|
|
good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode));
|
|
// Get cudnn mathtype (cudnnMathType_t)
|
|
good = (good && get_cudnn_mathtype(math_type, op_desc));
|
|
good = (good && get_cudnn_datatype(
|
|
compute_type,
|
|
op_desc.tile_description.math_instruction.element_accumulator));
|
|
// Check cutlass Conv2d description has equivalent operator in cudnn
|
|
if (!good) {
|
|
status = Status::kErrorNotSupported;
|
|
return;
|
|
}
|
|
// cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue)
|
|
alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha);
|
|
beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta);
|
|
|
|
// Create convolution descriptor object
|
|
status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc));
|
|
|
|
// Configure convolution operator
|
|
std::vector<int> padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w};
|
|
std::vector<int> stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w};
|
|
std::vector<int> dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w};
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetConvolutionNdDescriptor(
|
|
conv_desc,
|
|
op_desc.conv_dim,
|
|
padding.data(),
|
|
stride.data(),
|
|
dilation.data(),
|
|
conv_mode,
|
|
compute_type
|
|
));
|
|
|
|
// Set groups
|
|
status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups));
|
|
|
|
// Create activation, filter, and output descriptor objects
|
|
status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc));
|
|
status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc));
|
|
status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc));
|
|
|
|
// Set activation, filter, and output descriptor
|
|
status = get_cutlass_status(
|
|
cudnnSetTensor4dDescriptor(
|
|
activation_desc,
|
|
layout_activation,
|
|
data_type_activation,
|
|
configuration.problem_size.N,
|
|
configuration.problem_size.C,
|
|
configuration.problem_size.H,
|
|
configuration.problem_size.W
|
|
));
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetFilter4dDescriptor(
|
|
filter_desc,
|
|
data_type_filter,
|
|
layout_filter,
|
|
configuration.problem_size.K,
|
|
configuration.problem_size.C,
|
|
configuration.problem_size.R,
|
|
configuration.problem_size.S
|
|
));
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetTensor4dDescriptor(
|
|
output_desc,
|
|
layout_output,
|
|
data_type_output,
|
|
configuration.problem_size.N,
|
|
configuration.problem_size.K,
|
|
configuration.problem_size.P,
|
|
configuration.problem_size.Q
|
|
));
|
|
|
|
// Set math instruction to tensor op
|
|
status = get_cutlass_status(
|
|
cudnnSetConvolutionMathType(conv_desc, math_type));
|
|
|
|
// Initialize workspace
|
|
switch (conv_kind) {
|
|
case library::ConvKind::kFprop:
|
|
status = get_cutlass_status(
|
|
cudnnGetConvolutionForwardWorkspaceSize(
|
|
handle,
|
|
activation_desc,
|
|
filter_desc,
|
|
conv_desc,
|
|
output_desc,
|
|
fprop_algo,
|
|
&workspace_size_in_bytes
|
|
)); break;
|
|
case library::ConvKind::kDgrad:
|
|
status = get_cutlass_status(
|
|
cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
handle,
|
|
filter_desc,
|
|
output_desc,
|
|
conv_desc,
|
|
activation_desc,
|
|
dgrad_algo,
|
|
&workspace_size_in_bytes
|
|
)); break;
|
|
case library::ConvKind::kWgrad:
|
|
status = get_cutlass_status(
|
|
cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
handle,
|
|
activation_desc,
|
|
output_desc,
|
|
conv_desc,
|
|
filter_desc,
|
|
wgrad_algo,
|
|
&workspace_size_in_bytes
|
|
)); break;
|
|
|
|
}
|
|
|
|
workspace = cutlass::device_memory::allocation<char>(workspace_size_in_bytes);
|
|
}
|
|
|
|
|
|
// ctor for conv3d
|
|
cudnnConvDispatcher(
|
|
library::ConvDescription const &op_desc,
|
|
library::Conv3dConfiguration configuration,
|
|
library::ConvArguments arguments_,
|
|
cudnnHandle_t handle
|
|
):
|
|
//configuration(configuration_),
|
|
arguments(arguments_),
|
|
conv_kind(op_desc.conv_kind),
|
|
status(Status::kSuccess) {
|
|
|
|
bool good = true;
|
|
|
|
// Get cudnn datatype, layout, and convolution mode from library::ConvDescription
|
|
good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element));
|
|
good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element));
|
|
good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element));
|
|
|
|
good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout));
|
|
good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout));
|
|
good = (good && get_cudnn_layout(layout_output, op_desc.C.layout));
|
|
|
|
good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode));
|
|
|
|
// cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue)
|
|
alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha);
|
|
beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta);
|
|
|
|
good = (good && get_cudnn_datatype(
|
|
compute_type,
|
|
op_desc.tile_description.math_instruction.element_accumulator));
|
|
|
|
// Check cutlass Conv2d description has equivalent operator in cudnn
|
|
if (!good) {
|
|
status = Status::kErrorNotSupported;
|
|
}
|
|
|
|
// Create convolution descriptor object
|
|
status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc));
|
|
|
|
// Configure convolution operator
|
|
std::vector<int> padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w};
|
|
std::vector<int> stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w};
|
|
std::vector<int> dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w};
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetConvolutionNdDescriptor(
|
|
conv_desc,
|
|
op_desc.conv_dim,
|
|
padding.data(),
|
|
stride.data(),
|
|
dilation.data(),
|
|
conv_mode,
|
|
compute_type
|
|
));
|
|
|
|
// Set groups
|
|
status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups));
|
|
|
|
// Create activation, filter, and output descriptor objects
|
|
status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc));
|
|
status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc));
|
|
status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc));
|
|
|
|
// Set activation descriptor
|
|
std::vector<int> activation_extent {
|
|
configuration.problem_size.N,
|
|
configuration.problem_size.C,
|
|
configuration.problem_size.D,
|
|
configuration.problem_size.H,
|
|
configuration.problem_size.W
|
|
};
|
|
|
|
std::vector<int> activation_stride {
|
|
configuration.layout_activations.stride()[3],
|
|
1,
|
|
configuration.layout_activations.stride()[2],
|
|
configuration.layout_activations.stride()[1],
|
|
configuration.layout_activations.stride()[0]
|
|
};
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetTensorNdDescriptor(
|
|
activation_desc,
|
|
data_type_activation,
|
|
op_desc.conv_dim + 2,
|
|
activation_extent.data(),
|
|
activation_stride.data()
|
|
));
|
|
|
|
// Set filter descriptor
|
|
std::vector<int> filter_extent {
|
|
configuration.problem_size.K,
|
|
configuration.problem_size.C,
|
|
configuration.problem_size.T,
|
|
configuration.problem_size.R,
|
|
configuration.problem_size.S
|
|
};
|
|
|
|
std::vector<int> filter_stride {
|
|
configuration.layout_filters.stride()[3],
|
|
1,
|
|
configuration.layout_filters.stride()[2],
|
|
configuration.layout_filters.stride()[1],
|
|
configuration.layout_filters.stride()[0]
|
|
};
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetFilterNdDescriptor(
|
|
filter_desc,
|
|
data_type_filter,
|
|
layout_filter,
|
|
op_desc.conv_dim + 2,
|
|
filter_extent.data()
|
|
));
|
|
|
|
|
|
// Set output descriptor
|
|
std::vector<int> output_extent {
|
|
configuration.problem_size.N,
|
|
configuration.problem_size.K,
|
|
configuration.problem_size.Z,
|
|
configuration.problem_size.P,
|
|
configuration.problem_size.Q
|
|
};
|
|
|
|
std::vector<int> output_stride {
|
|
configuration.layout_output.stride()[3],
|
|
1,
|
|
configuration.layout_output.stride()[2],
|
|
configuration.layout_output.stride()[1],
|
|
configuration.layout_output.stride()[0]
|
|
};
|
|
|
|
status = get_cutlass_status(
|
|
cudnnSetTensorNdDescriptor(
|
|
output_desc,
|
|
data_type_output,
|
|
op_desc.conv_dim + 2,
|
|
output_extent.data(),
|
|
output_stride.data()
|
|
));
|
|
|
|
// Set math instruction to tensor op
|
|
status = get_cutlass_status(
|
|
cudnnSetConvolutionMathType(conv_desc, math_type));
|
|
|
|
// Initialize workspace
|
|
switch (conv_kind) {
|
|
case library::ConvKind::kFprop:
|
|
status = get_cutlass_status(
|
|
cudnnGetConvolutionForwardWorkspaceSize(
|
|
handle,
|
|
activation_desc,
|
|
filter_desc,
|
|
conv_desc,
|
|
output_desc,
|
|
fprop_algo,
|
|
&workspace_size_in_bytes
|
|
)); break;
|
|
case library::ConvKind::kDgrad:
|
|
status = get_cutlass_status(
|
|
cudnnGetConvolutionBackwardDataWorkspaceSize(
|
|
handle,
|
|
filter_desc,
|
|
output_desc,
|
|
conv_desc,
|
|
activation_desc,
|
|
dgrad_algo,
|
|
&workspace_size_in_bytes
|
|
)); break;
|
|
case library::ConvKind::kWgrad:
|
|
status = get_cutlass_status(
|
|
cudnnGetConvolutionBackwardFilterWorkspaceSize(
|
|
handle,
|
|
activation_desc,
|
|
output_desc,
|
|
conv_desc,
|
|
filter_desc,
|
|
wgrad_algo,
|
|
&workspace_size_in_bytes
|
|
)); break;
|
|
|
|
}
|
|
|
|
workspace = cutlass::device_memory::allocation<char>(workspace_size_in_bytes);
|
|
}
|
|
|
|
/// Executes Conv2d operater from cudnn library
|
|
cudnnStatus_t operator()(cudnnHandle_t handle) {
|
|
|
|
switch (conv_kind) {
|
|
case library::ConvKind::kFprop:
|
|
return cudnnConvolutionForward(
|
|
handle,
|
|
&alpha,
|
|
activation_desc,
|
|
activation(),
|
|
filter_desc,
|
|
filter(),
|
|
conv_desc,
|
|
fprop_algo,
|
|
workspace.get(),
|
|
workspace_size_in_bytes,
|
|
&beta,
|
|
output_desc,
|
|
arguments.D
|
|
);
|
|
case library::ConvKind::kDgrad:
|
|
return cudnnConvolutionBackwardData(
|
|
handle,
|
|
&alpha,
|
|
filter_desc,
|
|
filter(),
|
|
output_desc,
|
|
output(),
|
|
conv_desc,
|
|
dgrad_algo,
|
|
workspace.get(),
|
|
workspace_size_in_bytes,
|
|
&beta,
|
|
activation_desc,
|
|
arguments.D
|
|
);
|
|
case library::ConvKind::kWgrad:
|
|
return cudnnConvolutionBackwardFilter(
|
|
handle,
|
|
&alpha,
|
|
activation_desc,
|
|
activation(),
|
|
output_desc,
|
|
output(),
|
|
conv_desc,
|
|
wgrad_algo,
|
|
workspace.get(),
|
|
workspace_size_in_bytes,
|
|
&beta,
|
|
filter_desc,
|
|
arguments.D
|
|
);
|
|
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
|
|
}
|
|
}
|
|
|
|
// Returns Actviation Tensor
|
|
void const * activation() const {
|
|
switch(conv_kind) {
|
|
case library::ConvKind::kFprop : return arguments.A;
|
|
case library::ConvKind::kDgrad : return arguments.C;
|
|
case library::ConvKind::kWgrad : return arguments.B;
|
|
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
|
|
}
|
|
}
|
|
|
|
// Returns Filter Tensor
|
|
void const *filter() const {
|
|
switch(conv_kind) {
|
|
case library::ConvKind::kFprop : return arguments.B;
|
|
case library::ConvKind::kDgrad : return arguments.B;
|
|
case library::ConvKind::kWgrad : return arguments.C;
|
|
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
|
|
}
|
|
}
|
|
|
|
// Returns Output Tensor
|
|
void const *output() const {
|
|
switch(conv_kind) {
|
|
case library::ConvKind::kFprop : return arguments.C;
|
|
case library::ConvKind::kDgrad : return arguments.A;
|
|
case library::ConvKind::kWgrad : return arguments.A;
|
|
default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)");
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace detail
|
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
|
#endif //#if CUTLASS_ENABLE_CUDNN
|
|
} // namespace profiler
|
|
} // namespace cutlass
|