cutlass/tools/library/src/reference/conv_reference_operation.h

634 lines
17 KiB
C
Raw Normal View History

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2022-04-24 03:02:38 +08:00
* SPDX-License-Identifier: BSD-3-Clause
*
2022-04-24 03:02:38 +08:00
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
2022-04-24 03:02:38 +08:00
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/* \file
\brief Defines operations for all CONV operation kinds in CUTLASS Library
*/
#pragma once
#include <iostream>
#include <sstream>
#include <cstring>
#include "cutlass/cutlass.h"
#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/util.h"
#include "library_internal.h"
#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/device/convolution.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace library {
///////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
template <
Provider kProvider,
conv::Operator ConvolutionalOperator,
int ConvDim,
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
struct ConvReferenceDispatcher;
/// Dispatcher for Conv2d (partially specialized for kConvDim == 2)
template <
Provider kProvider,
conv::Operator kConvolutionalOperator,
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator,
typename ConvertOp,
typename InnerProductOp
>
struct ConvReferenceDispatcher<
kProvider,
kConvolutionalOperator,
2,
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp> {
static Status dispatch(
void const *configuration,
ElementA *ptr_A,
ElementB *ptr_B,
ElementC *ptr_C,
ElementC *ptr_D,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr
) {
Conv2dConfiguration const &config =
*static_cast<Conv2dConfiguration const *>(configuration);
2021-02-26 22:58:26 +08:00
// TODO: make below code more general. It is fixed for NHWC now.
layout::TensorNHWC layout_a;
layout::TensorNHWC layout_b;
layout::TensorNHWC layout_c;
layout_a.stride() =
2021-07-23 12:40:53 +08:00
make_Coord(int32_t(config.stride_a[0]),
int32_t(config.stride_a[1]),
int32_t(config.stride_a[2]));
2021-02-26 22:58:26 +08:00
layout_b.stride() =
2021-07-23 12:40:53 +08:00
make_Coord(int32_t(config.stride_b[0]),
int32_t(config.stride_b[1]),
int32_t(config.stride_b[2]));
2021-02-26 22:58:26 +08:00
layout_c.stride() =
2021-07-23 12:40:53 +08:00
make_Coord(int32_t(config.stride_c[0]),
int32_t(config.stride_c[1]),
int32_t(config.stride_c[2]));
if (kProvider == Provider::kReferenceHost) {
cutlass::reference::host::Conv2d<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC ,
LayoutC,
ElementCompute,
ElementAccumulator,
ElementC,
ConvertOp,
InnerProductOp
>(
kConvolutionalOperator,
config.problem_size,
2021-02-26 22:58:26 +08:00
{ptr_A, layout_a},
{ptr_B, layout_b},
{ptr_C, layout_c},
{ptr_D, layout_c},
alpha,
beta
);
return Status::kSuccess;
}
else if (kProvider == Provider::kReferenceDevice) {
return cutlass::reference::device::Conv2d<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp
>(
kConvolutionalOperator,
config.problem_size,
2021-02-26 22:58:26 +08:00
{ptr_A, layout_a},
{ptr_B, layout_b},
{ptr_C, layout_c},
{ptr_D, layout_c},
alpha,
beta,
stream
);
}
return Status::kErrorNotSupported;
}
};
/// Dispatcher for Conv3d (partially specialized for kConvDim == 3)
template <
Provider kProvider,
conv::Operator kConvolutionalOperator,
typename ElementA,
typename LayoutA,
typename ElementB,
typename LayoutB,
typename ElementC,
typename LayoutC,
typename ElementCompute,
typename ElementAccumulator,
typename ConvertOp,
typename InnerProductOp
>
struct ConvReferenceDispatcher<
kProvider,
kConvolutionalOperator,
3,
ElementA, LayoutA,
ElementB, LayoutB,
ElementC, LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp> {
static Status dispatch(
void const *configuration,
ElementA *ptr_A,
ElementB *ptr_B,
ElementC *ptr_C,
ElementC *ptr_D,
ElementCompute alpha,
ElementCompute beta,
cudaStream_t stream = nullptr
) {
Conv3dConfiguration const &config =
*static_cast<Conv3dConfiguration const *>(configuration);
ConvKind const conv_kind = ConvKindMap<kConvolutionalOperator>::kId;
if (kProvider == Provider::kReferenceHost) {
cutlass::reference::host::Conv3d<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC ,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp
>(
kConvolutionalOperator,
config.problem_size,
{ptr_A, config.layout_a(conv_kind)},
{ptr_B, config.layout_b(conv_kind)},
{ptr_C, config.layout_c(conv_kind)},
{ptr_D, config.layout_c(conv_kind)},
alpha,
beta
);
return Status::kSuccess;
}
else if (kProvider == Provider::kReferenceDevice) {
return cutlass::reference::device::Conv3d<
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp
>(
kConvolutionalOperator,
config.problem_size,
{ptr_A, config.layout_a(conv_kind)},
{ptr_B, config.layout_b(conv_kind)},
{ptr_C, config.layout_c(conv_kind)},
{ptr_D, config.layout_c(conv_kind)},
alpha,
beta,
stream
);
}
return Status::kErrorNotSupported;
}
};
} // namespace detail
///////////////////////////////////////////////////////////////////////////////////////////////////
template <
Provider Provider_,
conv::Operator ConvolutionalOperator,
int ConvDim,
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
class ConvReferenceOperation : public Operation {
public:
static Provider const kProvider = Provider_;
static conv::Operator const kConvolutionalOperator = ConvolutionalOperator;
static int const kConvDim = ConvDim;
using ElementA = ElementA_;
using LayoutA = LayoutA_;
using ElementB = ElementB_;
using LayoutB = LayoutB_;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using ElementCompute = ElementCompute_;
using ElementAccumulator = ElementAccumulator_;
using ConvertOp = ConvertOp_;
using InnerProductOp = InnerProductOp_;
protected:
/// Storage for the name string
std::string name_;
///
ConvDescription description_;
public:
/// Constructor
ConvReferenceOperation() {
// Basic information
description_.provider = kProvider;
description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d);
description_.conv_kind = ConvKindMap<kConvolutionalOperator>::kId;
description_.conv_dim = kConvDim;
// Tensor description
description_.A = make_TensorDescription<ElementA, LayoutA>();
description_.B = make_TensorDescription<ElementB, LayoutB>();
description_.C = make_TensorDescription<ElementC, LayoutC>();
// Epilogue compute and accumulator type description
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
// Iterator algorithm for convolution reference
description_.iterator_algorithm = IteratorAlgorithmID::kNone;
// Compute capability for convolution reference
description_.tile_description.minimum_compute_capability =
(kProvider == Provider::kReferenceDevice ? 50 : 0);
description_.tile_description.maximum_compute_capability = 1024;
// Procedural name
std::stringstream ss;
ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind)
<< "_reference_" << to_string(description_.provider)
<< "_" << to_string(description_.A.element) << to_string(description_.A.layout)
<< "_" << to_string(description_.B.element) << to_string(description_.B.layout)
<< "_" << to_string(description_.C.element) << to_string(description_.C.layout)
<< "_" << to_string(description_.tile_description.math_instruction.element_accumulator);
name_ = ss.str();
description_.name = name_.c_str();
// Epilogue compute and accumulator type description
description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;
description_.tile_description.math_instruction.element_accumulator =
NumericTypeMap<ElementAccumulator>::kId;
}
/// Returns the description of the GEMM operation
virtual OperationDescription const & description() const {
return description_;
}
virtual Status can_implement(
void const *configuration,
void const *arguments) const {
return Status::kSuccess;
}
virtual uint64_t get_host_workspace_size(
void const *configuration) const {
switch (kConvDim) {
case 2:
return sizeof(Conv2dConfiguration);
case 3:
return sizeof(Conv3dConfiguration);
default:
break;
}
return 0;
}
virtual uint64_t get_device_workspace_size(
void const *configuration,
void const *arguments = nullptr) const {
return 0;
}
virtual Status initialize(
void const *configuration,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration));
return Status::kSuccess;
}
virtual Status run(
void const *arguments,
void *host_workspace,
void *device_workspace = nullptr,
cudaStream_t stream = nullptr) const {
ConvArguments const &args = *static_cast<ConvArguments const *>(arguments);
ElementCompute alpha;
ElementCompute beta;
alpha = *static_cast<ElementCompute const *>(args.alpha);
beta = *static_cast<ElementCompute const *>(args.beta);
// TODO - respect pointer mode
// Invoke 2D or 3D convolution
return detail::ConvReferenceDispatcher<
kProvider,
kConvolutionalOperator,
kConvDim,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementC,
LayoutC,
ElementCompute,
ElementAccumulator,
ConvertOp,
InnerProductOp
>::dispatch(
host_workspace,
static_cast<ElementA *>(const_cast<void *>(args.A)),
static_cast<ElementB *>(const_cast<void *>(args.B)),
static_cast<ElementC *>(const_cast<void *>(args.C)),
static_cast<ElementC *>(args.D),
alpha,
beta,
stream
);
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Constructs Fprop reference operators.
template <
int kConvDim,
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_conv_fprop(Manifest &manifest) {
manifest.append(new ConvReferenceOperation<
Provider::kReferenceHost,
conv::Operator::kFprop,
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>);
manifest.append(new ConvReferenceOperation<
Provider::kReferenceDevice,
conv::Operator::kFprop,
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>);
}
/// Constructs Dgrad and Wgrad reference operators.
template <
int kConvDim,
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_conv_backwards(Manifest &manifest) {
manifest.append(new ConvReferenceOperation<
Provider::kReferenceHost,
conv::Operator::kDgrad,
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>);
manifest.append(new ConvReferenceOperation<
Provider::kReferenceDevice,
conv::Operator::kDgrad,
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>);
manifest.append(new ConvReferenceOperation<
Provider::kReferenceHost,
conv::Operator::kWgrad,
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>);
manifest.append(new ConvReferenceOperation<
Provider::kReferenceDevice,
conv::Operator::kWgrad,
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>);
}
/// Six operators for the price of one.
template <
int kConvDim,
typename ElementA_,
typename LayoutA_,
typename ElementB_,
typename LayoutB_,
typename ElementC_,
typename LayoutC_,
typename ElementCompute_,
typename ElementAccumulator_ = ElementCompute_,
typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
typename InnerProductOp_ = multiply_add<ElementAccumulator_>
>
void make_conv_all(Manifest &manifest) {
make_conv_fprop<
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>(manifest);
make_conv_backwards<
kConvDim,
ElementA_, LayoutA_,
ElementB_, LayoutB_,
ElementC_, LayoutC_,
ElementCompute_,
ElementAccumulator_,
ConvertOp_,
InnerProductOp_
>(manifest);
}
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace library
} // namespace cutlass
///////////////////////////////////////////////////////////////////////////////////////////////////