// File: cutlass/include/cutlass/conv/convnd_problem_shape.hpp
// (575 lines, 22 KiB, C++; last modified 2024-03-19 17:51:04 -04:00)
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief This file contains definitions and utility functions for describing convolution problem shapes.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/tensor_coord.h"
#include "cutlass/conv/convolution.h"

#include "cute/container/array.hpp"

#if ! defined(__CUDACC_RTC__)
#include <algorithm>          // std::copy
#include <cassert>            // assert
#include <initializer_list>
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::conv {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion.
// All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK)
// Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types.
template <
  conv::Operator ConvOp_,
  int NumSpatialDimensions
>
struct ConvProblemShape {
  //
  // Alias types for members
  //
  static constexpr int RankS = NumSpatialDimensions;      // rank of the spatial extents (e.g. 3 for D,H,W)
  static constexpr int RankT = NumSpatialDimensions + 2;  // full tensor rank: batch + spatial + channel
  static constexpr conv::Operator ConvOp = ConvOp_;

  using SpatialExtent   = cute::array<int, RankS>;
  using TensorExtent    = cute::array<int, RankT>;
  using TensorStride    = cute::array<int64_t, RankT>;
  using ShapePadding    = SpatialExtent;
  using TraversalStride = SpatialExtent;
  using ShapeDilation   = SpatialExtent;
  using Corner          = SpatialExtent;

  //
  // Members
  //
  cutlass::conv::Mode mode{};
  // A/B/C follow the GEMM-operand mapping chosen by set_shape_stride_ABC()
  // (which tensor plays A/B/C depends on ConvOp; see the table there).
  TensorExtent shape_A{};
  TensorStride stride_A{};
  TensorExtent shape_B{};
  TensorStride stride_B{};
  TensorExtent shape_C{};
  TensorStride stride_C{};
  // asymmetric padding, both upper and lower padding must be >= 0
  ShapePadding lower_padding{};
  ShapePadding upper_padding{};
  TraversalStride traversal_stride{};
  ShapeDilation dilation{};
  int groups = 1;

  //
  // Methods
  //
  ConvProblemShape() = default;

  // Constructor accepts user facing arguments, computes the output (xformed activation)
  // shape with packed strides, and stores the operand shapes/strides as its internal state.
  ConvProblemShape(
      conv::Mode mode,            // convolution/cross-correlation
      TensorExtent shape_act,     // [n,d,h,w,c]
      TensorStride stride_act,    // [n,d,h,w,c]
      TensorExtent shape_flt,     // [k,t,r,s,c]
      TensorStride stride_flt,    // [k,t,r,s,c]
      ShapePadding lower_padding, // [pad_d, pad_h, pad_w]
      ShapePadding upper_padding, // [pad_d, pad_h, pad_w]
      TraversalStride tstride,    // [stride_d, stride_h, stride_w]
      ShapeDilation dilation,     // [dilation_d, dilation_h, dilation_w]
      int groups)
  : mode(mode)
  , lower_padding(lower_padding)
  , upper_padding(upper_padding)
  , traversal_stride(tstride)
  , dilation(dilation)
  , groups(groups) {
    // Spatial parameters must be set (done in the init-list above) before
    // calculate_xformed_act() runs, since it reads them.
    auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
    set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
  }

  // Allow user input of xformed activation stride to support non-packed strides.
  // Every stride must still be channel-innermost (unit stride in the last mode)
  // and no smaller than the packed stride of the corresponding shape.
  ConvProblemShape(
      conv::Mode mode,                  // convolution/cross-correlation
      TensorExtent shape_act,           // [n,d,h,w,c]
      TensorStride stride_act,          // [n,d,h,w,c]
      TensorExtent shape_flt,           // [k,t,r,s,c]
      TensorStride stride_flt,          // [k,t,r,s,c]
      TensorStride stride_xformed_act,  // [n,z,p,q,k]
      ShapePadding lower_padding,       // [pad_d, pad_h, pad_w]
      ShapePadding upper_padding,       // [pad_d, pad_h, pad_w]
      TraversalStride tstride,          // [stride_d, stride_h, stride_w]
      ShapeDilation dilation,           // [dilation_d, dilation_h, dilation_w]
      int groups)
  : mode(mode)
  , lower_padding(lower_padding)
  , upper_padding(upper_padding)
  , traversal_stride(tstride)
  , dilation(dilation)
  , groups(groups) {
    // Layout-right tensors must have unit stride in their innermost (channel) mode.
    CUTLASS_ASSERT(stride_act[RankT - 1] == 1);
    CUTLASS_ASSERT(stride_flt[RankT - 1] == 1);
    CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1);

    // Non-packed strides are only allowed to over-allocate, never under-allocate.
    auto stride_act_packed = packed_stride_right_major(shape_act);
    auto stride_flt_packed = packed_stride_right_major(shape_flt);
    auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < RankT - 1; ++i) {
      CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]);
      CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]);
      CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]);
    }

    set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
  }

  // Constructor accepts user facing arguments and presumes packed tensor strides
  // in canonical (CWHDN) order; delegates to the fully general constructor.
  ConvProblemShape(
      conv::Mode mode,
      TensorExtent shape_act,
      TensorExtent shape_flt,
      ShapePadding lower_padding,
      ShapePadding upper_padding,
      TraversalStride tstride,
      ShapeDilation dilation,
      int groups)
  : ConvProblemShape(
      mode,
      shape_act,
      packed_stride_right_major(shape_act),
      shape_flt,
      packed_stride_right_major(shape_flt),
      lower_padding,
      upper_padding,
      tstride,
      dilation,
      groups) {
  }

#if ! defined(__CUDACC_RTC__)
  // Host-only convenience constructor taking initializer lists; validates list
  // lengths before copying, then computes the xformed activation state.
  ConvProblemShape(
      conv::Mode mode,
      std::initializer_list<int> shape_act_,
      std::initializer_list<int64_t> stride_act_,
      std::initializer_list<int> shape_flt_,
      std::initializer_list<int64_t> stride_flt_,
      std::initializer_list<int> lower_padding_,
      std::initializer_list<int> upper_padding_,
      std::initializer_list<int> traversal_stride_,
      std::initializer_list<int> dilation_,
      int groups)
  : mode(mode)
  , groups(groups) {
    TensorExtent shape_act{};
    TensorStride stride_act{};
    TensorExtent shape_flt{};
    TensorStride stride_flt{};

    assert(shape_act_.size() == shape_act.size());
    assert(stride_act_.size() == stride_act.size());
    assert(shape_flt_.size() == shape_flt.size());
    assert(stride_flt_.size() == stride_flt.size());
    assert(lower_padding_.size() == lower_padding.size());
    assert(upper_padding_.size() == upper_padding.size());
    assert(traversal_stride_.size() == traversal_stride.size());
    assert(dilation_.size() == dilation.size());

    std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
    std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin());
    std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
    std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin());
    std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
    std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
    std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
    std::copy(dilation_.begin(), dilation_.end(), dilation.begin());

    auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
    set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
  }

  // Host-only initializer-list variant that additionally accepts a user-provided
  // xformed activation stride to support non-packed strides.
  ConvProblemShape(
      conv::Mode mode,
      std::initializer_list<int> shape_act_,
      std::initializer_list<int64_t> stride_act_,
      std::initializer_list<int> shape_flt_,
      std::initializer_list<int64_t> stride_flt_,
      std::initializer_list<int64_t> stride_xformed_act_,
      std::initializer_list<int> lower_padding_,
      std::initializer_list<int> upper_padding_,
      std::initializer_list<int> traversal_stride_,
      std::initializer_list<int> dilation_,
      int groups)
  : mode(mode)
  , groups(groups) {
    TensorExtent shape_act{};
    TensorStride stride_act{};
    TensorExtent shape_flt{};
    TensorStride stride_flt{};
    TensorStride stride_xformed_act{};

    // Fix: validate list lengths before copying, matching the sibling
    // initializer_list constructors. Without these checks an over-long list
    // would write past the end of the fixed-size cute::array destinations.
    assert(shape_act_.size() == shape_act.size());
    assert(stride_act_.size() == stride_act.size());
    assert(shape_flt_.size() == shape_flt.size());
    assert(stride_flt_.size() == stride_flt.size());
    assert(stride_xformed_act_.size() == stride_xformed_act.size());
    assert(lower_padding_.size() == lower_padding.size());
    assert(upper_padding_.size() == upper_padding.size());
    assert(traversal_stride_.size() == traversal_stride.size());
    assert(dilation_.size() == dilation.size());

    std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
    std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin());
    std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
    std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin());
    std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin());
    std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
    std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
    std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
    std::copy(dilation_.begin(), dilation_.end(), dilation.begin());

    // Same stride validation as the array-based non-packed-stride constructor.
    CUTLASS_ASSERT(stride_act[RankT - 1] == 1);
    CUTLASS_ASSERT(stride_flt[RankT - 1] == 1);
    CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1);

    auto stride_act_packed = packed_stride_right_major(shape_act);
    auto stride_flt_packed = packed_stride_right_major(shape_flt);
    auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt);

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < RankT - 1; ++i) {
      CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]);
      CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]);
      CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]);
    }

    set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
  }

  // Host-only initializer-list variant that presumes packed tensor strides.
  ConvProblemShape(
      conv::Mode mode,
      std::initializer_list<int> shape_act_,
      std::initializer_list<int> shape_flt_,
      std::initializer_list<int> lower_padding_,
      std::initializer_list<int> upper_padding_,
      std::initializer_list<int> traversal_stride_,
      std::initializer_list<int> dilation_,
      int groups)
  : mode(mode)
  , groups(groups) {
    TensorExtent shape_act{};
    TensorStride stride_act{};
    TensorExtent shape_flt{};
    TensorStride stride_flt{};

    assert(shape_act_.size() == shape_act.size());
    assert(shape_flt_.size() == shape_flt.size());
    assert(lower_padding_.size() == lower_padding.size());
    assert(upper_padding_.size() == upper_padding.size());
    assert(traversal_stride_.size() == traversal_stride.size());
    assert(dilation_.size() == dilation.size());

    std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin());
    std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin());
    std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin());
    std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin());
    std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin());
    std::copy(dilation_.begin(), dilation_.end(), dilation.begin());

    stride_act = packed_stride_right_major(shape_act);
    stride_flt = packed_stride_right_major(shape_flt);

    auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt);
    set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act);
  }
#endif // not defined(__CUDACC_RTC__)

  // Set shape and stride of tensor A/B/C according to following table:
  // |        | Fprop  | Dgrad  | Wgrad |
  // | ------ | ------ | ------ | ------|
  // | ShapeA | NDHWC  | NZPQK  | NZPQK |
  // | ShapeB | KTRSC  | KTRSC  | NDHWC |
  // | ShapeC | NZPQK  | NDHWC  | KTRSC |
  //
  CUTLASS_HOST_DEVICE
  constexpr void
  set_shape_stride_ABC(
      TensorExtent shape_act,
      TensorStride stride_act,
      TensorExtent shape_flt,
      TensorStride stride_flt,
      TensorExtent shape_xformed_act,
      TensorStride stride_xformed_act) {
    if constexpr (ConvOp == cutlass::conv::Operator::kFprop) {
      shape_A  = shape_act;
      stride_A = stride_act;
      shape_B  = shape_flt;
      stride_B = stride_flt;
      shape_C  = shape_xformed_act;
      stride_C = stride_xformed_act;
    }
    else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) {
      shape_A  = shape_xformed_act;
      stride_A = stride_xformed_act;
      shape_B  = shape_flt;
      stride_B = stride_flt;
      shape_C  = shape_act;
      stride_C = stride_act;
    }
    else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) {
      shape_A  = shape_xformed_act;
      stride_A = stride_xformed_act;
      shape_B  = shape_act;
      stride_B = stride_act;
      shape_C  = shape_flt;
      stride_C = stride_flt;
    }
  }

  // Get problem shape MNK according to following table:
  // |         | Fprop     | Dgrad     | Wgrad     |
  // | ----    | --------- | --------  | --------  |
  // | Shape_M | (Q,P,Z,N) | (W,H,D,N) | (K)       |
  // | Shape_N | (K)       | (C)       | (C,S,R,T) |
  // | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) |
  CUTLASS_HOST_DEVICE
  constexpr auto
  get_transformed_problem_shape_MNK() const {
    using cute::insert;
    using cute::make_shape;
    using cute::reverse;
    using cute::take;

    if constexpr (ConvOp == conv::Operator::kWgrad) {
      auto M_xformed = shape_C[0];
      auto N_xformed = reverse(take<1, RankT>(shape_C));
      auto K_xformed = reverse(take<0, RankT - 1>(shape_A));
      return make_shape(M_xformed, N_xformed, K_xformed);
    }
    else if constexpr (ConvOp == conv::Operator::kFprop) {
      auto M_xformed = reverse(take<0, RankT - 1>(shape_C));
      auto N_xformed = shape_C[RankT - 1];
      auto K_xformed = reverse(take<1, RankT>(shape_B));
      return make_shape(M_xformed, N_xformed, K_xformed);
    }
    else if constexpr (ConvOp == conv::Operator::kDgrad) {
      auto M_xformed = reverse(take<0, RankT - 1>(shape_C));
      auto N_xformed = shape_C[RankT - 1];
      // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T]
      // K stays outermost; only the spatial filter modes are reversed.
      auto K_xformed = insert<0>(
        (reverse(take<1, RankT - 1>(shape_B))),
        shape_B[0]);
      return make_shape(M_xformed, N_xformed, K_xformed);
    }
  }

  // Get A extents.
  // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C))
  // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N))
  // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K))
  CUTLASS_HOST_DEVICE
  constexpr auto
  get_shape_A() const {
    using cute::make_shape;
    using cute::take;

    if constexpr (ConvOp == conv::Operator::kFprop ||
                  ConvOp == conv::Operator::kDgrad) {
      return make_shape(
        cute::reverse(take<0, RankT - 1>(shape_A)),
        shape_A[RankT - 1]);
    }
    // For wgrad kernel, we need to linearize NZPQ for tensor A
    else if constexpr (ConvOp == conv::Operator::kWgrad) {
      return make_shape(
        shape_A[RankT - 1],
        cute::product(take<0, RankT - 1>(shape_A)));
    }
  }

  // Get B extents.
  // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T))
  // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N))
  // dgrad: B extents array contains [K,T,R,S,C]. Turn that into ((C), (K,S,R,T))
  CUTLASS_HOST_DEVICE
  constexpr auto
  get_shape_B() const {
    using cute::make_shape;
    using cute::reverse;
    using cute::take;

    if constexpr (ConvOp == conv::Operator::kFprop) {
      return make_shape(
        shape_B[0],
        reverse(take<1, RankT>(shape_B)));
    }
    else if constexpr (ConvOp == conv::Operator::kWgrad) {
      return make_shape(
        shape_B[RankT - 1],
        reverse(take<0, RankT - 1>(shape_B)));
    }
    else if constexpr (ConvOp == conv::Operator::kDgrad) {
      // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)]
      return make_shape(
        shape_B[RankT - 1],
        cute::insert<0>(
          reverse(take<1, RankT - 1>(shape_B)),
          shape_B[0]));
    }
  }

  // Static method that returns the canonical strides of tensors (layouts are right major and compact)
  CUTLASS_HOST_DEVICE
  static constexpr TensorStride
  packed_stride_right_major(TensorExtent const& extents) {
    TensorStride strides{};
    strides[RankT - 1] = 1;
    // Walk from innermost to outermost so each stride is the running product
    // of the inner extents. int64_t accumulation avoids 32-bit overflow.
    cute::for_each(cute::make_rseq<RankT - 1>{}, [&](auto i) {
      strides[i] = extents[i + 1] * strides[i + 1];
    });
    return strides;
  }

  // Static method that returns the packed logical size (element count) of any TensorExtent
  CUTLASS_HOST_DEVICE
  static constexpr size_t
  size(TensorExtent const& extents) {
    size_t size = 1;
    cute::for_each(cute::make_seq<RankT>{}, [&](auto i) {
      size *= extents[i];
    });
    return size;
  }

  // Allocation sizes in elements: outermost extent times outermost stride,
  // which also covers non-packed (over-allocated) layouts.
  CUTLASS_HOST_DEVICE
  constexpr size_t
  size_A() const {
    return shape_A[0] * stride_A[0];
  }

  CUTLASS_HOST_DEVICE
  constexpr size_t
  size_B() const {
    return shape_B[0] * stride_B[0];
  }

  CUTLASS_HOST_DEVICE
  constexpr size_t
  size_C() const {
    return shape_C[0] * stride_C[0];
  }

  // Equality operator
  // NOTE(review): compares only tensor extents (A/B) and the spatial
  // parameters. mode, groups, and the strides are NOT compared, so two
  // problems differing only in those fields compare equal — confirm this is
  // the intended (extent-only) equivalence before relying on it as a key.
  CUTLASS_HOST_DEVICE
  bool operator==(ConvProblemShape<ConvOp, NumSpatialDimensions> const& rhs) const {
    using cute::for_each;
    using cute::make_seq;
    bool is_equal = true;

    // Compare all tensor extents (shape_C is derived from A/B, so A/B suffice)
    for_each(make_seq<RankT>{}, [&](auto i) {
      is_equal = is_equal
        && (shape_A[i] == rhs.shape_A[i])
        && (shape_B[i] == rhs.shape_B[i]);
    });

    // Compare all spatial parameters
    for_each(make_seq<RankS>{}, [&](auto i) {
      is_equal = is_equal
        && (lower_padding[i] == rhs.lower_padding[i])
        && (upper_padding[i] == rhs.upper_padding[i])
        && (traversal_stride[i] == rhs.traversal_stride[i])
        && (dilation[i] == rhs.dilation[i]);
    });

    return is_equal;
  }

  /// Inequality operator
  CUTLASS_HOST_DEVICE
  bool operator!=(ConvProblemShape<ConvOp, NumSpatialDimensions> const& rhs) const {
    return !(*this == rhs);
  }

private:
  // Computes the output (xformed activation) extents [n,z,p,q,k] from the
  // activation and filter extents plus the already-stored padding, dilation,
  // and traversal stride; returns (shape, packed right-major stride).
  // Reads members only — hence const (fix: was non-const).
  CUTLASS_HOST_DEVICE
  constexpr auto
  calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) const {
    TensorExtent shape_xformed_act{};

    // calculate n,z,p,q,k.
    // a helper lambda to compute a single spatial extent of the nzpqk tensor
    // (the standard conv output-extent formula with asymmetric padding folded
    // into pad_total).
    auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) {
      return 1 + (act_ext + pad_total - ((filter_ext - 1) * dilation + 1)) / tstride;
    };

    shape_xformed_act[0] = shape_act[0]; // Activation N extent
    cute::for_each(cute::make_seq<RankS>{}, [&](auto i) {
      shape_xformed_act[i + 1] = nzpqk_extent(
        shape_act[i + 1], shape_flt[i + 1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]);
    });
    shape_xformed_act[RankT - 1] = shape_flt[0]; // Filter K extent

    TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act);
    return cute::make_tuple(shape_xformed_act, stride_xformed_act);
  }
};
// Pretty-prints every member of a ConvProblemShape for debugging.
// Fix: previously omitted the `groups` member from the dump, so grouped-conv
// problems were indistinguishable in logs.
template<
  conv::Operator ConvOp,
  int SpatialDim
>
void print(ConvProblemShape<ConvOp, SpatialDim> const& problem) {
  printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n",
      SpatialDim, int(ConvOp));
  printf("\tTensorA: ");
  cute::print(problem.shape_A); printf(":");
  cute::print(problem.stride_A); printf("\n");
  printf("\tTensorB: ");
  cute::print(problem.shape_B); printf(":");
  cute::print(problem.stride_B); printf("\n");
  printf("\tTensorC: ");
  cute::print(problem.shape_C); printf(":");
  cute::print(problem.stride_C); printf("\n");
  printf("\tLower padding: "); print(problem.lower_padding); printf("\n");
  printf("\tUpper padding: "); print(problem.upper_padding); printf("\n");
  printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n");
  printf("\tDilation: "); print(problem.dilation); printf("\n");
  printf("\tGroups: %d\n", problem.groups);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::conv
////////////////////////////////////////////////////////////////////////////////////////////////////