cutlass/include/cutlass/coord.h
HouQiming 96a11a1ef3
Removed trivial copy constructors on parameter classes to enable devi… (#366)
* Removed trivial copy constructors on parameter classes to enable device-side launch of CUTLASS kernels

* Added SFINAE to the `TensorRef(NonConstTensorRef const&)` constructor to avoid making it a copy-constructor for device code

* std => platform

* fix affine2

* really fix affine2

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2022-02-28 21:34:02 -05:00

466 lines
11 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright notice, this list of
* conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright notice, this list of
* conditions and the following disclaimer in the documentation and/or other materials
* provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
* to endorse or promote products derived from this software without specific prior written
* permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A Coord is a coordinate of arbitrary rank into a tensor or matrix
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cstdint>
#else
#include <stdint.h>
#endif
#include "cutlass/cutlass.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically-sized array specifying Coords within a tensor
template <
int Rank_, ///< Logical rank of coordinate
typename Index_ = int, ///< Index type used for each dimension
typename LongIndex_ = int64_t ///< Long index type used for linear offsets
>
struct Coord {
public:
//
// Type and constant definitions
//
/// Number of elements in Coord
static int const kRank = Rank_;
/// Index type used to store elements
using Index = Index_;
/// Type used to represent linear offsets
using LongIndex = LongIndex_;
private:
//
// Data members
//
/// Indices
Index idx[kRank];
public:
//
// Methods
//
/// Default ctor initializes uniformly
CUTLASS_HOST_DEVICE
explicit Coord(Index value = Index(0)) {
for (int i = 0; i < kRank; ++i) {
idx[i] = value;
}
}
/// Constructs from an array of integers
CUTLASS_HOST_DEVICE
Coord(Index const (&_idx)[kRank]) {
for (int i = 0; i < kRank; ++i) {
idx[i] = _idx[i];
}
}
/// Returns a slice of the Coord which may be larger or smaller in rank
/// than this.
template <int Slice>
CUTLASS_HOST_DEVICE
Coord<Slice> slice(int start = 0, Index identity = 0) const {
Coord<Slice> result;
for (int i = 0; i < Slice; ++i) {
if (i + start < kRank) {
result[i] = idx[i + start];
}
else {
result[i] = identity;
}
}
return result;
}
/// Returns the index of the dimension with least value
CUTLASS_HOST_DEVICE
int min_dim_index() const {
int i = 0;
for (int j = 1; j < kRank; ++j) {
if (idx[j] < idx[i]) {
i = j;
}
}
return i;
}
/// Returns the index of the dimension with greatest value
CUTLASS_HOST_DEVICE
int max_dim_index() const {
int i = 0;
for (int j = 1; j < kRank; ++j) {
if (idx[j] > idx[i]) {
i = j;
}
}
return i;
}
/// Returns true if Coord is non-zero.
CUTLASS_HOST_DEVICE
explicit operator bool() const {
for (int i = 0; i < kRank; ++i) {
if (idx[i]) {
return true;
}
}
return false;
}
/// Returns true if Coord is uniformly zero.
CUTLASS_HOST_DEVICE
bool operator!() const {
for (int i = 0; i < kRank; ++i) {
if (idx[i]) {
return false;
}
}
return true;
}
/// Element-wise addition
CUTLASS_HOST_DEVICE
Coord operator+(Coord const& b) const {
Coord c;
for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] + b.idx[i];
}
return c;
}
/// Element-wise subtraction
CUTLASS_HOST_DEVICE
Coord operator-(Coord const& b) const {
Coord c;
for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] - b.idx[i];
}
return c;
}
/// Element-wise multiplication
CUTLASS_HOST_DEVICE
Coord operator*(Coord const& b) const {
Coord c;
for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] * b.idx[i];
}
return c;
}
/// Element-wise division
CUTLASS_HOST_DEVICE
Coord operator/(Coord const& b) const {
Coord c;
for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] / b.idx[i];
}
return c;
}
/// In-place addition
CUTLASS_HOST_DEVICE
Coord& operator+=(Coord const& b) {
for (int i = 0; i < kRank; ++i) {
idx[i] += b.idx[i];
}
return *this;
}
/// In-place subtraction
CUTLASS_HOST_DEVICE
Coord& operator-=(Coord const& b) {
for (int i = 0; i < kRank; ++i) {
idx[i] -= b.idx[i];
}
return *this;
}
/// In-place multiplication
CUTLASS_HOST_DEVICE
Coord& operator*=(Coord const& b) {
for (int i = 0; i < kRank; ++i) {
idx[i] *= b.idx[i];
}
return *this;
}
/// In-place division
CUTLASS_HOST_DEVICE
Coord& operator/=(Coord const& b) {
for (int i = 0; i < kRank; ++i) {
idx[i] /= b.idx[i];
}
return *this;
}
/// Member access operator
CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; }
/// Member access operator
CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; }
/// Computes the dot product with anotherCoord object
CUTLASS_HOST_DEVICE
LongIndex dot(Coord const& b, LongIndex sum = LongIndex(0)) const {
for (int i = 0; i < kRank; ++i) {
sum += idx[i] * b.idx[i];
}
return sum;
}
/// Gets the index of a given Coord element
template <int Dim>
CUTLASS_HOST_DEVICE Index& at() {
return idx[Dim];
}
/// Access via index; may limit unrolling potential
CUTLASS_HOST_DEVICE
Index& at(int dim) { return idx[dim]; }
/// Gets the index of a given Coord element
template <int Dim>
CUTLASS_HOST_DEVICE Index const& at() const {
return idx[Dim];
}
/// Access via index; may limit unrolling potential
CUTLASS_HOST_DEVICE
Index const& at(int dim) const { return idx[dim]; }
/// Determines if two Coord<> objects are equal
CUTLASS_HOST_DEVICE
bool operator==(Coord const& b) const {
bool equal = true;
for (int i = 0; equal && i < kRank; ++i) {
equal = (idx[i] == b.idx[i]);
}
return equal;
}
/// Not equal
CUTLASS_HOST_DEVICE
bool operator!=(Coord const& b) const { return !(*this == b); }
/// Clamps a coordinate to a range specified by maximum and minimum values
CUTLASS_HOST_DEVICE
Coord& clamp(Coord const& max, Coord const& min = Coord()) {
for (int i = 0; i < kRank; ++i) {
idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
}
return *this;
}
/// Returns the sum of all elements
CUTLASS_HOST_DEVICE
Index sum() const {
Index sum_(idx[0]);
for (int i = 1; i < kRank; ++i) {
sum_ += idx[i];
}
return sum_;
}
/// Returns the product of all elements
CUTLASS_HOST_DEVICE
LongIndex product() const {
LongIndex product_(idx[0]);
for (int i = 1; i < kRank; ++i) {
product_ *= idx[i];
}
return product_;
}
/// Less than operator
CUTLASS_HOST_DEVICE
bool operator<(Coord const &b) const {
for (int i = 0; i < kRank; ++i) {
if (!(idx[i] < b[i])) {
return false;
}
}
return true;
}
/// Less than or equals operator
CUTLASS_HOST_DEVICE
bool operator<=(Coord const &b) const {
for (int i = 0; i < kRank; ++i) {
if (!(idx[i] <= b[i])) {
return false;
}
}
return true;
}
/// Greater than operator
CUTLASS_HOST_DEVICE
bool operator>(Coord const &b) const {
return !(*this <= b);
}
/// Greater than or equals operator
CUTLASS_HOST_DEVICE
bool operator>=(Coord const &b) const {
return !(*this < b);
}
};
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
/// Scalar multiplication
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator*(Index s, Coord<Rank, Index> coord) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] *= s;
}
return coord;
}
/// Scalar multiplication
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator*(Coord<Rank, Index> coord, Index s) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] *= s;
}
return coord;
}
/// Scalar division
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator/(Index s, Coord<Rank, Index> coord) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] = s / coord[i];
}
return coord;
}
/// Scalar division
template <int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator/(Coord<Rank, Index> coord, Index s) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] /= s;
}
return coord;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Integer-valued make_Coord
//
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper to make a 2-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<1, T> make_Coord(T _0) {
T values[1] = {_0};
return Coord<1, T>(values);
}
/// Helper to make a 2-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<2, T> make_Coord(T _0, T _1) {
T values[2] = {_0, _1};
return Coord<2, T>(values);
}
/// Helper to make a 3-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<3, T> make_Coord(T _0, T _1, T _2) {
T values[3] = {_0, _1, _2};
return Coord<3, T>(values);
}
/// Helper to make a 4-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<4, T> make_Coord(T _0, T _1, T _2, T _3) {
T values[4] = {_0, _1, _2, _3};
return Coord<4, T>(values);
}
/// Helper to make a 5-element coordinate
template <typename T>
CUTLASS_HOST_DEVICE
Coord<5, T> make_Coord(T _0, T _1, T _2, T _3, T _4) {
T values[5] = {_0, _1, _2, _3, _4};
return Coord<5, T>(values);
}
/// Helper to make a 1-element coordinate
template <int N, typename T>
CUTLASS_HOST_DEVICE
Coord<N, T>make_Coord_with_padding(T _0) {
Coord<N, T> coord;
CUTLASS_PRAGMA_UNROLL
for (int i = N - 1; i > 0; --i) {
coord[i] = 0;
}
coord[0] = _0;
return coord;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass