2020-11-20 13:25:25 +08:00
|
|
|
/***************************************************************************************************
|
2023-01-21 05:32:57 +08:00
|
|
|
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2022-04-24 03:02:38 +08:00
|
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
2020-11-20 13:25:25 +08:00
|
|
|
*
|
2022-04-24 03:02:38 +08:00
|
|
|
* Redistribution and use in source and binary forms, with or without
|
|
|
|
* modification, are permitted provided that the following conditions are met:
|
2020-11-20 13:25:25 +08:00
|
|
|
*
|
2022-04-24 03:02:38 +08:00
|
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
|
|
* list of conditions and the following disclaimer.
|
|
|
|
*
|
|
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
|
|
* and/or other materials provided with the distribution.
|
|
|
|
*
|
|
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
|
|
* contributors may be used to endorse or promote products derived from
|
|
|
|
* this software without specific prior written permission.
|
|
|
|
*
|
|
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
2020-11-20 13:25:25 +08:00
|
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*
|
|
|
|
**************************************************************************************************/
|
|
|
|
/* \file
|
|
|
|
\brief Defines operations for all CONV operation kinds in CUTLASS Library
|
|
|
|
*/
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
#include <sstream>
|
|
|
|
#include <cstring>
|
|
|
|
|
|
|
|
#include "cutlass/cutlass.h"
|
|
|
|
|
|
|
|
#include "cutlass/library/library.h"
|
|
|
|
#include "cutlass/library/manifest.h"
|
|
|
|
#include "cutlass/library/util.h"
|
|
|
|
#include "library_internal.h"
|
|
|
|
|
|
|
|
#include "cutlass/util/reference/host/convolution.h"
|
|
|
|
#include "cutlass/util/reference/device/convolution.h"
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
namespace cutlass {
|
|
|
|
namespace library {
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
namespace detail {
|
|
|
|
|
|
|
|
template <
|
|
|
|
Provider kProvider,
|
|
|
|
conv::Operator ConvolutionalOperator,
|
|
|
|
int ConvDim,
|
|
|
|
typename ElementA_,
|
|
|
|
typename LayoutA_,
|
|
|
|
typename ElementB_,
|
|
|
|
typename LayoutB_,
|
|
|
|
typename ElementC_,
|
|
|
|
typename LayoutC_,
|
|
|
|
typename ElementCompute_,
|
|
|
|
typename ElementAccumulator_ = ElementCompute_,
|
|
|
|
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
|
|
|
|
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
|
|
|
>
|
|
|
|
struct ConvReferenceDispatcher;
|
|
|
|
|
2023-03-10 12:22:56 +08:00
|
|
|
/// Dispatcher for Conv2d (partially specialized for kConvDim == 2)
|
2020-11-20 13:25:25 +08:00
|
|
|
template <
|
|
|
|
Provider kProvider,
|
|
|
|
conv::Operator kConvolutionalOperator,
|
|
|
|
typename ElementA,
|
|
|
|
typename LayoutA,
|
|
|
|
typename ElementB,
|
|
|
|
typename LayoutB,
|
|
|
|
typename ElementC,
|
|
|
|
typename LayoutC,
|
|
|
|
typename ElementCompute,
|
|
|
|
typename ElementAccumulator,
|
|
|
|
typename ConvertOp,
|
|
|
|
typename InnerProductOp
|
|
|
|
>
|
|
|
|
struct ConvReferenceDispatcher<
|
|
|
|
kProvider,
|
|
|
|
kConvolutionalOperator,
|
|
|
|
2,
|
|
|
|
ElementA, LayoutA,
|
|
|
|
ElementB, LayoutB,
|
|
|
|
ElementC, LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp> {
|
|
|
|
|
|
|
|
static Status dispatch(
|
|
|
|
void const *configuration,
|
|
|
|
ElementA *ptr_A,
|
|
|
|
ElementB *ptr_B,
|
|
|
|
ElementC *ptr_C,
|
|
|
|
ElementC *ptr_D,
|
|
|
|
ElementCompute alpha,
|
|
|
|
ElementCompute beta,
|
|
|
|
cudaStream_t stream = nullptr
|
|
|
|
) {
|
|
|
|
|
|
|
|
Conv2dConfiguration const &config =
|
|
|
|
*static_cast<Conv2dConfiguration const *>(configuration);
|
|
|
|
|
2021-02-26 22:58:26 +08:00
|
|
|
// TODO: make below code more general. It is fixed for NHWC now.
|
|
|
|
layout::TensorNHWC layout_a;
|
|
|
|
layout::TensorNHWC layout_b;
|
|
|
|
layout::TensorNHWC layout_c;
|
|
|
|
|
|
|
|
layout_a.stride() =
|
2021-07-23 12:40:53 +08:00
|
|
|
make_Coord(int32_t(config.stride_a[0]),
|
|
|
|
int32_t(config.stride_a[1]),
|
|
|
|
int32_t(config.stride_a[2]));
|
2021-02-26 22:58:26 +08:00
|
|
|
|
|
|
|
layout_b.stride() =
|
2021-07-23 12:40:53 +08:00
|
|
|
make_Coord(int32_t(config.stride_b[0]),
|
|
|
|
int32_t(config.stride_b[1]),
|
|
|
|
int32_t(config.stride_b[2]));
|
2021-02-26 22:58:26 +08:00
|
|
|
|
|
|
|
layout_c.stride() =
|
2021-07-23 12:40:53 +08:00
|
|
|
make_Coord(int32_t(config.stride_c[0]),
|
|
|
|
int32_t(config.stride_c[1]),
|
|
|
|
int32_t(config.stride_c[2]));
|
2020-11-20 13:25:25 +08:00
|
|
|
|
|
|
|
if (kProvider == Provider::kReferenceHost) {
|
|
|
|
|
|
|
|
cutlass::reference::host::Conv2d<
|
|
|
|
ElementA,
|
|
|
|
LayoutA,
|
|
|
|
ElementB,
|
|
|
|
LayoutB,
|
|
|
|
ElementC ,
|
|
|
|
LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
2023-04-15 11:19:34 +08:00
|
|
|
ElementC,
|
2020-11-20 13:25:25 +08:00
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp
|
|
|
|
>(
|
|
|
|
kConvolutionalOperator,
|
|
|
|
config.problem_size,
|
2021-02-26 22:58:26 +08:00
|
|
|
{ptr_A, layout_a},
|
|
|
|
{ptr_B, layout_b},
|
|
|
|
{ptr_C, layout_c},
|
|
|
|
{ptr_D, layout_c},
|
2020-11-20 13:25:25 +08:00
|
|
|
alpha,
|
|
|
|
beta
|
|
|
|
);
|
|
|
|
|
|
|
|
return Status::kSuccess;
|
|
|
|
}
|
|
|
|
else if (kProvider == Provider::kReferenceDevice) {
|
|
|
|
return cutlass::reference::device::Conv2d<
|
|
|
|
ElementA,
|
|
|
|
LayoutA,
|
|
|
|
ElementB,
|
|
|
|
LayoutB,
|
|
|
|
ElementC,
|
|
|
|
LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp
|
|
|
|
>(
|
|
|
|
kConvolutionalOperator,
|
|
|
|
config.problem_size,
|
2021-02-26 22:58:26 +08:00
|
|
|
{ptr_A, layout_a},
|
|
|
|
{ptr_B, layout_b},
|
|
|
|
{ptr_C, layout_c},
|
|
|
|
{ptr_D, layout_c},
|
2020-11-20 13:25:25 +08:00
|
|
|
alpha,
|
|
|
|
beta,
|
|
|
|
stream
|
|
|
|
);
|
|
|
|
}
|
|
|
|
return Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Dispatcher for Conv3d (partially specialized for kConvDim == 3)
|
|
|
|
template <
|
|
|
|
Provider kProvider,
|
|
|
|
conv::Operator kConvolutionalOperator,
|
|
|
|
typename ElementA,
|
|
|
|
typename LayoutA,
|
|
|
|
typename ElementB,
|
|
|
|
typename LayoutB,
|
|
|
|
typename ElementC,
|
|
|
|
typename LayoutC,
|
|
|
|
typename ElementCompute,
|
|
|
|
typename ElementAccumulator,
|
|
|
|
typename ConvertOp,
|
|
|
|
typename InnerProductOp
|
|
|
|
>
|
|
|
|
struct ConvReferenceDispatcher<
|
|
|
|
kProvider,
|
|
|
|
kConvolutionalOperator,
|
|
|
|
3,
|
|
|
|
ElementA, LayoutA,
|
|
|
|
ElementB, LayoutB,
|
|
|
|
ElementC, LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp> {
|
|
|
|
|
|
|
|
static Status dispatch(
|
|
|
|
void const *configuration,
|
|
|
|
ElementA *ptr_A,
|
|
|
|
ElementB *ptr_B,
|
|
|
|
ElementC *ptr_C,
|
|
|
|
ElementC *ptr_D,
|
|
|
|
ElementCompute alpha,
|
|
|
|
ElementCompute beta,
|
|
|
|
cudaStream_t stream = nullptr
|
|
|
|
) {
|
|
|
|
|
|
|
|
Conv3dConfiguration const &config =
|
|
|
|
*static_cast<Conv3dConfiguration const *>(configuration);
|
|
|
|
|
|
|
|
ConvKind const conv_kind = ConvKindMap<kConvolutionalOperator>::kId;
|
|
|
|
|
|
|
|
if (kProvider == Provider::kReferenceHost) {
|
|
|
|
cutlass::reference::host::Conv3d<
|
|
|
|
ElementA,
|
|
|
|
LayoutA,
|
|
|
|
ElementB,
|
|
|
|
LayoutB,
|
|
|
|
ElementC ,
|
|
|
|
LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp
|
|
|
|
>(
|
|
|
|
kConvolutionalOperator,
|
|
|
|
config.problem_size,
|
|
|
|
{ptr_A, config.layout_a(conv_kind)},
|
|
|
|
{ptr_B, config.layout_b(conv_kind)},
|
|
|
|
{ptr_C, config.layout_c(conv_kind)},
|
|
|
|
{ptr_D, config.layout_c(conv_kind)},
|
|
|
|
alpha,
|
|
|
|
beta
|
|
|
|
);
|
|
|
|
|
|
|
|
return Status::kSuccess;
|
|
|
|
}
|
|
|
|
else if (kProvider == Provider::kReferenceDevice) {
|
|
|
|
return cutlass::reference::device::Conv3d<
|
|
|
|
ElementA,
|
|
|
|
LayoutA,
|
|
|
|
ElementB,
|
|
|
|
LayoutB,
|
|
|
|
ElementC,
|
|
|
|
LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp
|
|
|
|
>(
|
|
|
|
kConvolutionalOperator,
|
|
|
|
config.problem_size,
|
|
|
|
{ptr_A, config.layout_a(conv_kind)},
|
|
|
|
{ptr_B, config.layout_b(conv_kind)},
|
|
|
|
{ptr_C, config.layout_c(conv_kind)},
|
|
|
|
{ptr_D, config.layout_c(conv_kind)},
|
|
|
|
alpha,
|
|
|
|
beta,
|
|
|
|
stream
|
|
|
|
);
|
|
|
|
}
|
|
|
|
return Status::kErrorNotSupported;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
} // namespace detail
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
template <
|
|
|
|
Provider Provider_,
|
|
|
|
conv::Operator ConvolutionalOperator,
|
|
|
|
int ConvDim,
|
|
|
|
typename ElementA_,
|
|
|
|
typename LayoutA_,
|
|
|
|
typename ElementB_,
|
|
|
|
typename LayoutB_,
|
|
|
|
typename ElementC_,
|
|
|
|
typename LayoutC_,
|
|
|
|
typename ElementCompute_,
|
|
|
|
typename ElementAccumulator_ = ElementCompute_,
|
|
|
|
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
|
|
|
|
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
|
|
|
>
|
|
|
|
class ConvReferenceOperation : public Operation {
|
|
|
|
public:
|
|
|
|
static Provider const kProvider = Provider_;
|
|
|
|
static conv::Operator const kConvolutionalOperator = ConvolutionalOperator;
|
|
|
|
static int const kConvDim = ConvDim;
|
|
|
|
|
|
|
|
using ElementA = ElementA_;
|
|
|
|
using LayoutA = LayoutA_;
|
|
|
|
using ElementB = ElementB_;
|
|
|
|
using LayoutB = LayoutB_;
|
|
|
|
using ElementC = ElementC_;
|
|
|
|
using LayoutC = LayoutC_;
|
|
|
|
using ElementCompute = ElementCompute_;
|
|
|
|
using ElementAccumulator = ElementAccumulator_;
|
|
|
|
using ConvertOp = ConvertOp_;
|
|
|
|
using InnerProductOp = InnerProductOp_;
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
/// Storage for the name string
|
|
|
|
std::string name_;
|
|
|
|
|
|
|
|
///
|
|
|
|
ConvDescription description_;
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
|
|
|
/// Constructor
|
|
|
|
ConvReferenceOperation() {
|
|
|
|
|
|
|
|
// Basic information
|
|
|
|
description_.provider = kProvider;
|
|
|
|
description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d);
|
|
|
|
description_.conv_kind = ConvKindMap<kConvolutionalOperator>::kId;
|
|
|
|
description_.conv_dim = kConvDim;
|
|
|
|
|
|
|
|
// Tensor description
|
|
|
|
description_.A = make_TensorDescription<ElementA, LayoutA>();
|
|
|
|
description_.B = make_TensorDescription<ElementB, LayoutB>();
|
|
|
|
description_.C = make_TensorDescription<ElementC, LayoutC>();
|
|
|
|
|
|
|
|
// Epilogue compute and accumulator type description
|
|
|
|
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
|
|
|
|
|
|
|
description_.tile_description.math_instruction.element_accumulator =
|
|
|
|
NumericTypeMap<ElementAccumulator>::kId;
|
|
|
|
|
|
|
|
// Iterator algorithm for convolution reference
|
|
|
|
description_.iterator_algorithm = IteratorAlgorithmID::kNone;
|
|
|
|
|
|
|
|
// Compute capability for convolution reference
|
|
|
|
description_.tile_description.minimum_compute_capability =
|
|
|
|
(kProvider == Provider::kReferenceDevice ? 50 : 0);
|
|
|
|
|
|
|
|
description_.tile_description.maximum_compute_capability = 1024;
|
|
|
|
|
|
|
|
// Procedural name
|
|
|
|
std::stringstream ss;
|
|
|
|
|
|
|
|
ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind)
|
|
|
|
<< "_reference_" << to_string(description_.provider)
|
|
|
|
<< "_" << to_string(description_.A.element) << to_string(description_.A.layout)
|
|
|
|
<< "_" << to_string(description_.B.element) << to_string(description_.B.layout)
|
|
|
|
<< "_" << to_string(description_.C.element) << to_string(description_.C.layout)
|
|
|
|
<< "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
|
|
|
|
|
|
|
|
name_ = ss.str();
|
|
|
|
|
|
|
|
description_.name = name_.c_str();
|
|
|
|
|
|
|
|
// Epilogue compute and accumulator type description
|
|
|
|
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
|
|
|
|
|
|
|
|
description_.tile_description.math_instruction.element_accumulator =
|
|
|
|
NumericTypeMap<ElementAccumulator>::kId;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns the description of the GEMM operation
|
|
|
|
virtual OperationDescription const & description() const {
|
|
|
|
return description_;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual Status can_implement(
|
|
|
|
void const *configuration,
|
|
|
|
void const *arguments) const {
|
|
|
|
|
|
|
|
return Status::kSuccess;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual uint64_t get_host_workspace_size(
|
|
|
|
void const *configuration) const {
|
|
|
|
|
|
|
|
switch (kConvDim) {
|
|
|
|
case 2:
|
|
|
|
return sizeof(Conv2dConfiguration);
|
|
|
|
case 3:
|
|
|
|
return sizeof(Conv3dConfiguration);
|
|
|
|
default:
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual uint64_t get_device_workspace_size(
|
2022-01-27 23:37:37 +08:00
|
|
|
void const *configuration,
|
|
|
|
void const *arguments = nullptr) const {
|
2020-11-20 13:25:25 +08:00
|
|
|
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual Status initialize(
|
|
|
|
void const *configuration,
|
|
|
|
void *host_workspace,
|
|
|
|
void *device_workspace = nullptr,
|
|
|
|
cudaStream_t stream = nullptr) const {
|
|
|
|
|
|
|
|
std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration));
|
|
|
|
|
|
|
|
return Status::kSuccess;
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual Status run(
|
|
|
|
void const *arguments,
|
|
|
|
void *host_workspace,
|
|
|
|
void *device_workspace = nullptr,
|
|
|
|
cudaStream_t stream = nullptr) const {
|
|
|
|
|
|
|
|
ConvArguments const &args = *static_cast<ConvArguments const *>(arguments);
|
|
|
|
|
|
|
|
ElementCompute alpha;
|
|
|
|
ElementCompute beta;
|
|
|
|
|
|
|
|
alpha = *static_cast<ElementCompute const *>(args.alpha);
|
|
|
|
beta = *static_cast<ElementCompute const *>(args.beta);
|
|
|
|
|
|
|
|
// TODO - respect pointer mode
|
|
|
|
|
|
|
|
// Invoke 2D or 3D convolution
|
|
|
|
return detail::ConvReferenceDispatcher<
|
|
|
|
kProvider,
|
|
|
|
kConvolutionalOperator,
|
|
|
|
kConvDim,
|
|
|
|
ElementA,
|
|
|
|
LayoutA,
|
|
|
|
ElementB,
|
|
|
|
LayoutB,
|
|
|
|
ElementC,
|
|
|
|
LayoutC,
|
|
|
|
ElementCompute,
|
|
|
|
ElementAccumulator,
|
|
|
|
ConvertOp,
|
|
|
|
InnerProductOp
|
|
|
|
>::dispatch(
|
|
|
|
host_workspace,
|
|
|
|
static_cast<ElementA *>(const_cast<void *>(args.A)),
|
|
|
|
static_cast<ElementB *>(const_cast<void *>(args.B)),
|
|
|
|
static_cast<ElementC *>(const_cast<void *>(args.C)),
|
|
|
|
static_cast<ElementC *>(args.D),
|
|
|
|
alpha,
|
|
|
|
beta,
|
|
|
|
stream
|
|
|
|
);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
/// Constructs Fprop reference operators.
|
|
|
|
template <
|
|
|
|
int kConvDim,
|
|
|
|
typename ElementA_,
|
|
|
|
typename LayoutA_,
|
|
|
|
typename ElementB_,
|
|
|
|
typename LayoutB_,
|
|
|
|
typename ElementC_,
|
|
|
|
typename LayoutC_,
|
|
|
|
typename ElementCompute_,
|
|
|
|
typename ElementAccumulator_ = ElementCompute_,
|
|
|
|
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
|
|
|
|
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
|
|
|
>
|
|
|
|
void make_conv_fprop(Manifest &manifest) {
|
|
|
|
|
|
|
|
manifest.append(new ConvReferenceOperation<
|
|
|
|
Provider::kReferenceHost,
|
|
|
|
conv::Operator::kFprop,
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>);
|
|
|
|
|
|
|
|
manifest.append(new ConvReferenceOperation<
|
|
|
|
Provider::kReferenceDevice,
|
|
|
|
conv::Operator::kFprop,
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Constructs Dgrad and Wgrad reference operators.
|
|
|
|
template <
|
|
|
|
int kConvDim,
|
|
|
|
typename ElementA_,
|
|
|
|
typename LayoutA_,
|
|
|
|
typename ElementB_,
|
|
|
|
typename LayoutB_,
|
|
|
|
typename ElementC_,
|
|
|
|
typename LayoutC_,
|
|
|
|
typename ElementCompute_,
|
|
|
|
typename ElementAccumulator_ = ElementCompute_,
|
|
|
|
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
|
|
|
|
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
|
|
|
>
|
|
|
|
void make_conv_backwards(Manifest &manifest) {
|
|
|
|
|
|
|
|
manifest.append(new ConvReferenceOperation<
|
|
|
|
Provider::kReferenceHost,
|
|
|
|
conv::Operator::kDgrad,
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>);
|
|
|
|
|
|
|
|
manifest.append(new ConvReferenceOperation<
|
|
|
|
Provider::kReferenceDevice,
|
|
|
|
conv::Operator::kDgrad,
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>);
|
|
|
|
|
|
|
|
manifest.append(new ConvReferenceOperation<
|
|
|
|
Provider::kReferenceHost,
|
|
|
|
conv::Operator::kWgrad,
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>);
|
|
|
|
|
|
|
|
manifest.append(new ConvReferenceOperation<
|
|
|
|
Provider::kReferenceDevice,
|
|
|
|
conv::Operator::kWgrad,
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Six operators for the price of one.
|
|
|
|
template <
|
|
|
|
int kConvDim,
|
|
|
|
typename ElementA_,
|
|
|
|
typename LayoutA_,
|
|
|
|
typename ElementB_,
|
|
|
|
typename LayoutB_,
|
|
|
|
typename ElementC_,
|
|
|
|
typename LayoutC_,
|
|
|
|
typename ElementCompute_,
|
|
|
|
typename ElementAccumulator_ = ElementCompute_,
|
|
|
|
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
|
|
|
|
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
|
|
|
|
>
|
|
|
|
void make_conv_all(Manifest &manifest) {
|
|
|
|
|
|
|
|
make_conv_fprop<
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>(manifest);
|
|
|
|
|
|
|
|
make_conv_backwards<
|
|
|
|
kConvDim,
|
|
|
|
ElementA_, LayoutA_,
|
|
|
|
ElementB_, LayoutB_,
|
|
|
|
ElementC_, LayoutC_,
|
|
|
|
ElementCompute_,
|
|
|
|
ElementAccumulator_,
|
|
|
|
ConvertOp_,
|
|
|
|
InnerProductOp_
|
|
|
|
>(manifest);
|
|
|
|
}
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
} // namespace library
|
|
|
|
} // namespace cutlass
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|