cutlass/include/cutlass/conv/conv3d_problem_size.h
2024-04-11 21:33:40 -04:00

514 lines
18 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains definitions and utility functions for describing convolution problem sizes.
Conv3dProblem description:
activation (NDHWC),
filter (KTRSC),
output (NZPQK),
padding (pad_d, pad_h, pad_w),
stride (stride_d, stride_h, stride_w),
dilation (dilation_d, dilation_h, dilation_w).
Free functions to map:
Map tensor extents (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_extent(ConvolutionOperator)
Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator)
Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator)
*/
#pragma once
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"
namespace cutlass {
namespace conv {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Problem size structure
/// Problem size structure for 3-D convolution (NDHWC activation, KTRSC filter, NZPQK output).
///
/// Derives from Conv2dProblemSize, which supplies N, H, W, C, K, R, S, P, Q, the h/w padding,
/// stride and dilation, and the `mode` / `split_k_slices` settings. This struct adds the depth
/// dimension: input depth D, filter depth T, output depth Z, and pad_d / stride_d / dilation_d.
struct Conv3dProblemSize : public Conv2dProblemSize {

  //
  // Type definitions
  //

  /// 3D coordinate for padding, stride, and dilation, ordered (d, h, w)
  using Coord3D = Coord<3>;

  //
  // Data members
  //

  // Parameters specific to Conv3d (h/w parameters live in the Conv2dProblemSize base)
  int D, T, Z;        // input depth, filter depth, output depth
  int pad_d;          // padding in depth dimension
  int stride_d;       // stride in depth dimension
  int dilation_d;     // dilation in depth dimension

  //
  // Methods
  //

public:

  /// Default constructor: zero depth extents, zero depth padding, unit depth stride/dilation.
  /// The h/w fields are initialized by Conv2dProblemSize().
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize():
    Conv2dProblemSize(),
    D(0), T(0), Z(0),
    pad_d(0),
    stride_d(1),
    dilation_d(1) { }

  /// Constructor for default padding, stride, dilation, and split-K.
  ///
  /// Depth padding is T / 2 and depth stride/dilation are 1; the h/w defaults are chosen by the
  /// corresponding Conv2dProblemSize constructor. The caller supplies the output extents Z, P, Q.
  /// Note the argument order (Z, P, Q before K, T, R, S) differs from the fully specified
  /// constructor that follows.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    int N,
    int D,
    int H,
    int W,
    int C,
    int Z,
    int P,
    int Q,
    int K,
    int T,
    int R,
    int S,
    Mode mode
  ):
    Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode),
    D(D), T(T), Z(Z),
    pad_d(T / 2), stride_d(1), dilation_d(1) { }

  /// Fully specified constructor: every extent, padding, stride and dilation is given explicitly.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    int N,
    int D,
    int H,
    int W,
    int C,
    int K,
    int T,
    int R,
    int S,
    int Z,
    int P,
    int Q,
    int pad_d,
    int pad_h,
    int pad_w,
    int stride_d,
    int stride_h,
    int stride_w,
    int dilation_d,
    int dilation_h,
    int dilation_w,
    Mode mode,
    int split_k_slices = 1,
    int groups = 1
  ):
    Conv2dProblemSize(
      N, H, W, C, K, R, S, P, Q,
      pad_h, pad_w,
      stride_h, stride_w,
      dilation_h, dilation_w,
      mode, split_k_slices, groups),
    D(D), T(T), Z(Z),
    pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { }

  /// Constructs the problem size from Tensor5DCoord extents and Coord3D parameters.
  /// The output size is *user-defined*: Z, P and Q are taken from `output_size`.
  /// Padding is symmetric: padding[1] / padding[2] are passed to the base as both the near and
  /// far h / w padding.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    cutlass::Tensor5DCoord input_size,    // NDHWC
    cutlass::Tensor5DCoord filter_size,   // KTRSC
    Coord3D padding,                      // pad_d, pad_h, pad_w
    Coord3D stride,                       // stride_d, stride_h, stride_w
    Coord3D dilation,                     // dilation_d, dilation_h, dilation_w
    cutlass::Tensor5DCoord output_size,   // NZPQK
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    Conv2dProblemSize(
      {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
      {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
      {padding[1], padding[1], padding[2], padding[2]},
      {stride[1], stride[2]},
      {dilation[1], dilation[2]},
      {output_size.n(), output_size.h(), output_size.w(), output_size.c()},
      mode, split_k_slices, groups),
    D(input_size.d()), T(filter_size.d()), Z(output_size.d()),
    pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { }

  /// Constructs the problem size from Tensor5DCoord extents and Coord3D parameters.
  /// The output size is *computed*: Z is computed here; P and Q are computed by the
  /// Conv2dProblemSize base constructor. Padding is symmetric in every dimension.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    cutlass::Tensor5DCoord input_size,    // NDHWC
    cutlass::Tensor5DCoord filter_size,   // KTRSC
    Coord3D padding,                      // pad_d, pad_h, pad_w
    Coord3D stride,                       // stride_d, stride_h, stride_w
    Coord3D dilation,                     // dilation_d, dilation_h, dilation_w
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    Conv2dProblemSize(
      {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
      {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
      {padding[1], padding[1], padding[2], padding[2]},
      {stride[1], stride[2]},
      {dilation[1], dilation[2]},
      mode, split_k_slices, groups),
    D(input_size.d()), T(filter_size.d()),
    pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0])
  {
    // Output depth: Z = (D + 2 * pad_d - T * dilation_d) / stride_d + 1.
    // NOTE(review): T * dilation_d exceeds the dilated filter extent (T - 1) * dilation_d + 1 by
    // dilation_d - 1. For dilation_d > 1 this can therefore give a smaller Z than the common
    // floor((D + 2*pad_d - dilation_d*(T-1) - 1) / stride_d) + 1 formula. Presumably kept to match
    // the base class's P/Q computation; confirm against conv2d_problem_size.h before changing.
    Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1;
  }

  /// Constructs the problem size from Tensor5DCoord extents with asymmetric padding.
  /// The output size is *computed*.
  ///
  /// get<0>(padding) is the near padding {pad_d, pad_h, pad_w} and get<1>(padding) the far
  /// padding. Only the near depth padding is stored (pad_d); the far depth padding is used solely
  /// to compute Z. The h/w near/far pairs are forwarded to the base, which computes P and Q.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize(
    cutlass::Tensor5DCoord input_size,    // NDHWC
    cutlass::Tensor5DCoord filter_size,   // KTRSC
    CUTLASS_STL_NAMESPACE::tuple<Coord3D, Coord3D> padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q
    Coord3D stride,                       // stride_d, stride_h, stride_w
    Coord3D dilation,                     // dilation_d, dilation_h, dilation_w
    cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation,
    int split_k_slices = 1,
    int groups = 1
  ):
    Conv2dProblemSize(
      {input_size.n(), input_size.h(), input_size.w(), input_size.c()},
      {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()},
      {CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1],
       CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]},
      {stride[1], stride[2]},
      {dilation[1], dilation[2]},
      mode, split_k_slices, groups),
    D(input_size.d()), T(filter_size.d()),
    pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0])
  {
    // Output depth using near + far depth padding (same T * dilation_d convention as above).
    Z = ((D + pad_d + CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1;
  }

  /// Equality operator. Compares extents, padding, stride and dilation only; mode,
  /// split_k_slices and any other base-class settings (e.g. groups) are not compared.
  CUTLASS_HOST_DEVICE
  bool operator==(Conv3dProblemSize const &conv) const {
    return (
      (N == conv.N) && (D == conv.D) && (H == conv.H) && (W == conv.W) && (C == conv.C) &&
      (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) &&
      (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) &&
      (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) &&
      (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) &&
      (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w)
    );
  }

  /// Inequality operator
  CUTLASS_HOST_DEVICE
  bool operator!=(Conv3dProblemSize const &rhs) const {
    return !(*this == rhs);
  }

  /// Returns a copy of this problem with the convolution mode replaced by `mode_`.
  /// `*this` is left unchanged, so the method is const.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize reset_mode(cutlass::conv::Mode mode_) const {
    Conv3dProblemSize tmp(*this);
    tmp.mode = mode_;
    return tmp;
  }

  /// Returns a copy of this problem with the split-K slice count replaced by `split_k_slices_`.
  /// `*this` is left unchanged, so the method is const.
  CUTLASS_HOST_DEVICE
  Conv3dProblemSize reset_split_k_slices(int split_k_slices_) const {
    Conv3dProblemSize tmp(*this);
    tmp.split_k_slices = split_k_slices_;
    return tmp;
  }

  /// Returns activation extent as Tensor5DCoord (N, D, H, W, C)
  CUTLASS_HOST_DEVICE
  cutlass::Tensor5DCoord activation_extent() const {
    return cutlass::Tensor5DCoord ({N, D, H, W, C});
  }

  /// Returns filter extent as Tensor5DCoord: (K, T, R, S, C), or (C, T, R, S, K) when
  /// `is_deconv` is true.
  CUTLASS_HOST_DEVICE
  cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const {
    return is_deconv ? cutlass::Tensor5DCoord ({C, T, R, S, K})
                     : cutlass::Tensor5DCoord ({K, T, R, S, C});
  }

  /// Returns output extent as Tensor5DCoord (N, Z, P, Q, K)
  CUTLASS_HOST_DEVICE
  cutlass::Tensor5DCoord output_extent() const {
    return cutlass::Tensor5DCoord ({N, Z, P, Q, K});
  }

  /// Returns activation size in number of elements (N * D * H * W * C).
  /// The product is formed in 64-bit arithmetic: five `int` extents can easily exceed
  /// INT_MAX for 3-D volumes, and signed `int` overflow is undefined behavior.
  CUTLASS_HOST_DEVICE
  int64_t activation_size() const {
    return int64_t(N) * D * H * W * C;
  }

  /// Returns filter size in number of elements (K * T * R * S * C), computed in 64-bit.
  CUTLASS_HOST_DEVICE
  int64_t filter_size() const {
    return int64_t(K) * T * R * S * C;
  }

  /// Returns output size in number of elements (N * Z * P * Q * K), computed in 64-bit.
  CUTLASS_HOST_DEVICE
  int64_t output_size() const {
    return int64_t(N) * Z * P * Q * K;
  }

  /// Returns padding as Coord3D (pad_d, pad_h, pad_w)
  CUTLASS_HOST_DEVICE
  Coord3D padding() const {
    return Coord3D ({pad_d, pad_h, pad_w});
  }

  /// Returns stride as Coord3D (stride_d, stride_h, stride_w)
  CUTLASS_HOST_DEVICE
  Coord3D stride() const {
    return Coord3D ({stride_d, stride_h, stride_w});
  }

  /// Returns dilation as Coord3D (dilation_d, dilation_h, dilation_w)
  CUTLASS_HOST_DEVICE
  Coord3D dilation() const {
    return Coord3D ({dilation_d, dilation_h, dilation_w});
  }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// ImplicitGemm helper functions //
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Determine the problem size of the implicit GEMM operation
/// Maps a Conv3d problem onto the (M, N, K) extents of the equivalent implicit GEMM.
///
/// - Fprop:          M = N*Z*P*Q, N = K, K = T*R*S*C
/// - Dgrad / Deconv: M = N*D*H*W, N = C, K = T*R*S*K
/// - Wgrad:          M = K,       N = T*R*S*C, K = N*Z*P*Q
/// Any other operator yields a default-constructed GemmCoord.
CUTLASS_HOST_DEVICE
cutlass::gemm::GemmCoord implicit_gemm_problem_size(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  Conv3dProblemSize const &ps = problem_size;

  if (conv_operator == Operator::kFprop) {
    int const output_pixels = ps.N * ps.Z * ps.P * ps.Q;
    int const filter_reduction = ps.T * ps.R * ps.S * ps.C;
    return gemm::GemmCoord(output_pixels, ps.K, filter_reduction);
  }

  if (conv_operator == Operator::kDgrad || conv_operator == Operator::kDeconv) {
    int const input_pixels = ps.N * ps.D * ps.H * ps.W;
    int const filter_reduction = ps.T * ps.R * ps.S * ps.K;
    return gemm::GemmCoord(input_pixels, ps.C, filter_reduction);
  }

  if (conv_operator == Operator::kWgrad) {
    int const filter_elements = ps.T * ps.R * ps.S * ps.C;
    int const output_pixels = ps.N * ps.Z * ps.P * ps.Q;
    return gemm::GemmCoord(ps.K, filter_elements, output_pixels);
  }

  // Unsupported operator
  return gemm::GemmCoord();
}
/// Returns the number of threadblock_K-sized GEMM-K iterations for a Conv3d problem computed
/// with the implicit GEMM algorithm.
///
/// GroupMode::kNone: the reduction dimension (C for Fprop, K for Dgrad/Deconv, N*Z*P*Q for
/// Wgrad) is divided across `split_k_slices` and then tiled by `threadblock_K`. For Fprop and
/// Dgrad/Deconv the tile count is further multiplied by the filter volume T*R*S.
///
/// GroupMode::kDepthwise: only the analytic Fprop case is handled, tiling `threadblock_N`
/// channels per CTA by `threadblock_K`.
///
/// Every other combination returns 0.
CUTLASS_HOST_DEVICE
int implicit_gemm_k_iterations(
  Operator conv_operator,
  int threadblock_K,
  Conv3dProblemSize const &problem_size,
  IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic,
  GroupMode group_mode = GroupMode::kNone,
  int threadblock_N = 0) {

  if (group_mode == GroupMode::kDepthwise) {
    bool const supported =
      (algorithm == IteratorAlgorithm::kAnalytic) && (conv_operator == Operator::kFprop);
    if (!supported) {
      return 0;
    }
    int const channels_per_cta = threadblock_N;
    int const channel_tiles = (channels_per_cta + threadblock_K - 1) / threadblock_K;
    return problem_size.T * problem_size.R * problem_size.S * channel_tiles;
  }

  if (group_mode != GroupMode::kNone) {
    return 0;
  }

  int const slices = problem_size.split_k_slices;

  if (conv_operator == Operator::kWgrad) {
    int const npq = problem_size.N * problem_size.Z * problem_size.P * problem_size.Q;
    int const npq_per_slice = (npq + slices - 1) / slices;
    return (npq_per_slice + threadblock_K - 1) / threadblock_K;
  }

  // Fprop reduces over input channels C; Dgrad/Deconv reduce over output channels K.
  int reduced_channels = 0;
  if (conv_operator == Operator::kFprop) {
    reduced_channels = problem_size.C;
  } else if (conv_operator == Operator::kDgrad || conv_operator == Operator::kDeconv) {
    reduced_channels = problem_size.K;
  } else {
    return 0;
  }

  int const channels_per_slice = (reduced_channels + slices - 1) / slices;
  int const channel_tiles = (channels_per_slice + threadblock_K - 1) / threadblock_K;
  return problem_size.T * problem_size.R * problem_size.S * channel_tiles;
}
////////////////////////////////////////////////////////////////////////////////
// Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output)
////////////////////////////////////////////////////////////////////////////////
/// Returns the extent of the implicit GEMM's A operand as Tensor5DCoord.
/// Fprop reads the activation; Dgrad, Deconv and Wgrad read the output tensor.
/// Any other operator yields a default-constructed Tensor5DCoord.
CUTLASS_HOST_DEVICE
cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  if (conv_operator == cutlass::conv::Operator::kFprop) {
    return problem_size.activation_extent();
  }

  bool const reads_output =
    conv_operator == cutlass::conv::Operator::kDgrad ||
    conv_operator == cutlass::conv::Operator::kDeconv ||
    conv_operator == cutlass::conv::Operator::kWgrad;

  return reads_output ? problem_size.output_extent() : cutlass::Tensor5DCoord();
}
/// Returns the extent of the implicit GEMM's B operand as Tensor5DCoord.
/// Fprop and Dgrad read the filter as KTRSC; Deconv reads it as CTRSK; Wgrad reads the
/// activation. Any other operator yields a default-constructed Tensor5DCoord.
CUTLASS_HOST_DEVICE
cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  if (conv_operator == cutlass::conv::Operator::kWgrad) {
    return problem_size.activation_extent();
  }

  if (conv_operator == cutlass::conv::Operator::kFprop ||
      conv_operator == cutlass::conv::Operator::kDgrad ||
      conv_operator == cutlass::conv::Operator::kDeconv) {
    bool const is_deconv = (conv_operator == cutlass::conv::Operator::kDeconv);
    return problem_size.filter_extent(is_deconv);
  }

  return cutlass::Tensor5DCoord();
}
/// Returns the extent of the implicit GEMM's C (accumulator/output) operand as Tensor5DCoord.
/// Fprop writes the output; Dgrad and Deconv write the activation; Wgrad writes the filter.
/// Any other operator yields a default-constructed Tensor5DCoord.
CUTLASS_HOST_DEVICE
cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  if (conv_operator == cutlass::conv::Operator::kFprop) {
    return problem_size.output_extent();
  } else if (conv_operator == cutlass::conv::Operator::kDgrad ||
             conv_operator == cutlass::conv::Operator::kDeconv) {
    return problem_size.activation_extent();
  } else if (conv_operator == cutlass::conv::Operator::kWgrad) {
    return problem_size.filter_extent();
  }

  return cutlass::Tensor5DCoord();
}
/// Returns the element count of the implicit GEMM's A operand.
/// Fprop: activation size; Dgrad, Deconv and Wgrad: output size; otherwise 0.
CUTLASS_HOST_DEVICE
int64_t implicit_gemm_tensor_a_size(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  if (conv_operator == cutlass::conv::Operator::kFprop) {
    return problem_size.activation_size();
  }

  bool const reads_output =
    conv_operator == cutlass::conv::Operator::kDgrad ||
    conv_operator == cutlass::conv::Operator::kDeconv ||
    conv_operator == cutlass::conv::Operator::kWgrad;

  if (reads_output) {
    return problem_size.output_size();
  }
  return 0;
}
/// Returns the element count of the implicit GEMM's B operand.
/// Fprop, Dgrad and Deconv: filter size (the Deconv layout swap does not change the count);
/// Wgrad: activation size; otherwise 0.
CUTLASS_HOST_DEVICE
int64_t implicit_gemm_tensor_b_size(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  if (conv_operator == cutlass::conv::Operator::kWgrad) {
    return problem_size.activation_size();
  }

  bool const reads_filter =
    conv_operator == cutlass::conv::Operator::kFprop ||
    conv_operator == cutlass::conv::Operator::kDgrad ||
    conv_operator == cutlass::conv::Operator::kDeconv;

  if (reads_filter) {
    return problem_size.filter_size();
  }
  return 0;
}
/// Returns the element count of the implicit GEMM's C operand.
/// Fprop: output size; Dgrad and Deconv: activation size; Wgrad: filter size; otherwise 0.
CUTLASS_HOST_DEVICE
int64_t implicit_gemm_tensor_c_size(
  Operator conv_operator,
  Conv3dProblemSize const &problem_size) {

  if (conv_operator == cutlass::conv::Operator::kFprop) {
    return problem_size.output_size();
  } else if (conv_operator == cutlass::conv::Operator::kDgrad ||
             conv_operator == cutlass::conv::Operator::kDeconv) {
    return problem_size.activation_size();
  } else if (conv_operator == cutlass::conv::Operator::kWgrad) {
    return problem_size.filter_size();
  }

  return 0;
}
} // namespace conv
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////