2018-05-17 02:44:56 +08:00
|
|
|
/***************************************************************************************************
|
|
|
|
* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
|
|
|
|
*
|
|
|
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
|
|
|
* provided that the following conditions are met:
|
|
|
|
* * Redistributions of source code must retain the above copyright notice, this list of
|
|
|
|
* conditions and the following disclaimer.
|
|
|
|
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
|
|
|
* conditions and the following disclaimer in the documentation and/or other materials
|
|
|
|
* provided with the distribution.
|
|
|
|
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
|
|
|
* to endorse or promote products derived from this software without specific prior written
|
|
|
|
* permission.
|
|
|
|
*
|
|
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
|
|
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
|
|
|
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
|
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
|
|
|
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
|
|
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
|
|
|
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*
|
|
|
|
**************************************************************************************************/
|
|
|
|
/*! \file
|
2018-09-19 07:58:03 +08:00
|
|
|
\brief Host-side implementation of basic tensor operations.
|
|
|
|
|
|
|
|
See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details.
|
2018-05-17 02:44:56 +08:00
|
|
|
*/
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstring>

#include "cutlass/cutlass.h"
#include "cutlass/tensor_view.h"
#include "tools/util/type_traits.h"
|
2018-05-17 02:44:56 +08:00
|
|
|
|
|
|
|
namespace cutlass {
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
template <
|
|
|
|
/// Data type of element stored within tensor
|
|
|
|
typename Storage_,
|
|
|
|
/// Rank of logical tensor
|
|
|
|
int Rank_ = 4,
|
|
|
|
/// Maps a Coord<Rank_> in the logical tensor index space to the internal n-D array
|
|
|
|
typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
|
|
|
|
/// Rank of internal n-D array
|
|
|
|
int StorageRank_ = Rank_,
|
|
|
|
/// Index type used for coordinates
|
|
|
|
typename Index_ = int,
|
|
|
|
/// Index type used for offsets and pointer differences
|
|
|
|
typename LongIndex_ = long long
|
|
|
|
>
|
|
|
|
class HostTensorView :
|
|
|
|
public TensorView<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
|
|
|
|
public:
|
|
|
|
/// Base class
|
|
|
|
typedef TensorView<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> Base;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Storage type
|
|
|
|
typedef typename Base::Storage Storage;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Alias for underlying TensorRef
|
|
|
|
typedef typename Base::TensorRef TensorRef;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Index type
|
|
|
|
typedef typename Base::Index Index;
|
|
|
|
|
|
|
|
/// Coordinate in logical tensor space
|
|
|
|
typedef typename TensorRef::TensorCoord TensorCoord;
|
|
|
|
|
|
|
|
/// Coordinate in storage n-D array
|
|
|
|
typedef typename TensorRef::StorageCoord StorageCoord;
|
|
|
|
|
|
|
|
/// Stride vector in storage coordinate space
|
|
|
|
/// Least significant stride is = 1 and not stored
|
|
|
|
typedef typename TensorRef::StrideVector StrideVector;
|
|
|
|
|
|
|
|
/// Long index type for pointer offsets
|
|
|
|
typedef typename Base::LongIndex LongIndex;
|
|
|
|
|
|
|
|
/// Rank of tensor index space
|
|
|
|
static int const kRank = Base::kRank;
|
|
|
|
|
|
|
|
//
|
|
|
|
// Definitions included for backwards compatibility - These will be remmoved
|
|
|
|
// in the next major release.
|
|
|
|
//
|
2018-05-17 02:44:56 +08:00
|
|
|
|
|
|
|
/// Base class
|
2018-09-19 07:58:03 +08:00
|
|
|
typedef Base TensorView_t;
|
|
|
|
|
|
|
|
//
|
|
|
|
// These definitions are meaningful for rank=4 tensors.
|
|
|
|
//
|
2018-05-17 02:44:56 +08:00
|
|
|
|
|
|
|
/// Convention: depth is the first dimension
|
|
|
|
static int const Dim_D = 0;
|
|
|
|
|
|
|
|
/// Convention: height is the second dimension
|
|
|
|
static int const Dim_H = 1;
|
|
|
|
|
|
|
|
/// Convention: width is the third dimension
|
|
|
|
static int const Dim_W = 2;
|
|
|
|
|
|
|
|
/// Convention: channel is the second dimension
|
|
|
|
static int const Dim_C = 3;
|
|
|
|
|
|
|
|
public:
|
2018-09-19 07:58:03 +08:00
|
|
|
|
2018-05-17 02:44:56 +08:00
|
|
|
//
|
|
|
|
// Device and Host Methods
|
|
|
|
//
|
|
|
|
|
|
|
|
/// Default constructor
|
|
|
|
HostTensorView() {}
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Helper to construct from pointer, stride, and size
|
|
|
|
HostTensorView(
|
|
|
|
Storage_ *_ptr,
|
|
|
|
StrideVector const &_stride,
|
|
|
|
TensorCoord const& _size
|
|
|
|
) : Base(TensorRef(_ptr, _stride), _size) {}
|
|
|
|
|
|
|
|
/// Helper to construct from pointer, stride, and size
|
|
|
|
HostTensorView(
|
|
|
|
Storage_ *_ptr,
|
|
|
|
StorageCoord const &_stride,
|
|
|
|
TensorCoord const& _size
|
|
|
|
) : Base(TensorRef(_ptr, _stride), _size) {}
|
|
|
|
|
|
|
|
/// Constructs a Tensor_view from a TensorRef and size assuming dense packing
|
|
|
|
HostTensorView(
|
|
|
|
TensorRef const& _ref,
|
|
|
|
TensorCoord const& _size) : Base(_ref, _size) {}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
|
|
|
/// Assigns a tensor view
|
2018-09-19 07:58:03 +08:00
|
|
|
HostTensorView& operator=(Base const& _tensor) {
|
|
|
|
this->reset(_tensor.ref(), _tensor.size());
|
2018-05-17 02:44:56 +08:00
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Returns a TensorView offset by a given amount
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
HostTensorView operator+(TensorCoord const& b) const {
|
|
|
|
HostTensorView result(*this);
|
|
|
|
result.add_pointer_offset(this->offset(b));
|
|
|
|
return result;
|
|
|
|
}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
HostTensorView& operator+=(TensorCoord const& b) {
|
|
|
|
this->add_pointer_offset(this->offset(b));
|
|
|
|
return *this;
|
|
|
|
}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
HostTensorView operator-(TensorCoord const& b) const {
|
|
|
|
TensorRef result(*this);
|
|
|
|
result.add_pointer_offset(-this->offset(b));
|
|
|
|
return result;
|
|
|
|
}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
|
|
HostTensorView& operator-=(TensorCoord const& b) {
|
|
|
|
this->add_pointer_offset(-this->offset(b));
|
|
|
|
return *this;
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Recurses through all dimensions and applies a unary operation in place
|
|
|
|
template <typename F>
|
2018-09-19 07:58:03 +08:00
|
|
|
void elementwise_in_place(F& op, int dim = 0, TensorCoord const &start_coord = TensorCoord()) {
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord coord(start_coord);
|
|
|
|
for (int idx = 0; idx < this->size(dim); ++idx) {
|
|
|
|
coord[dim] = idx;
|
|
|
|
if (dim < kRank - 1) {
|
|
|
|
elementwise_in_place(op, dim + 1, coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
} else {
|
2018-09-19 07:58:03 +08:00
|
|
|
op(this->at(coord));
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Recurses through all dimensions and applies a unary operator with no arguments
|
|
|
|
template <typename F>
|
2018-09-19 07:58:03 +08:00
|
|
|
void elementwise_stream(F& op, int dim = 0, TensorCoord const &start_coord = TensorCoord()) {
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord coord(start_coord);
|
|
|
|
for (int idx = 0; idx < this->size(dim); ++idx) {
|
|
|
|
coord[dim] = idx;
|
|
|
|
if (dim < kRank - 1) {
|
|
|
|
elementwise_stream(op, dim + 1, coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
} else {
|
2018-09-19 07:58:03 +08:00
|
|
|
this->at(coord) = op();
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Recurses through all dimensions and applies a unary operator, supplying the logical
|
|
|
|
/// coordinate within the tensor as an argument
|
|
|
|
template <typename F>
|
|
|
|
void elementwise_generate(F& op,
|
|
|
|
int dim = 0,
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord const & start_coord = TensorCoord()) {
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord coord(start_coord);
|
|
|
|
for (int idx = 0; idx < this->size(dim); ++idx) {
|
|
|
|
coord[dim] = idx;
|
|
|
|
if (dim < kRank - 1) {
|
|
|
|
elementwise_generate(op, dim + 1, coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
} else {
|
2018-09-19 07:58:03 +08:00
|
|
|
this->at(coord) = op(coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Recurses through all dimensions and applies a unary operator, supplying the logical
|
2018-09-19 07:58:03 +08:00
|
|
|
/// coordinate within the tensor as an argument. Mutable.
|
2018-05-17 02:44:56 +08:00
|
|
|
template <typename F>
|
|
|
|
void elementwise_visit(F& op,
|
|
|
|
int dim = 0,
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord const & start_coord = TensorCoord()) const {
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord coord(start_coord);
|
|
|
|
for (int idx = 0; idx < this->size(dim); ++idx) {
|
|
|
|
coord[dim] = idx;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
if (dim < kRank - 1) {
|
|
|
|
elementwise_visit(op, dim + 1, coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
} else {
|
2018-09-19 07:58:03 +08:00
|
|
|
op(this->at(coord), coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Recurses through all dimensions and applies a binary operation
|
2018-09-19 07:58:03 +08:00
|
|
|
template <typename F, typename SrcTensorView>
|
2018-05-17 02:44:56 +08:00
|
|
|
bool elementwise_in_place(F& op,
|
2018-09-19 07:58:03 +08:00
|
|
|
SrcTensorView const& tensor,
|
2018-05-17 02:44:56 +08:00
|
|
|
int dim = 0,
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord const &start_coord = TensorCoord()) {
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
if (this->size(dim) != tensor.size(dim)) {
|
2018-05-17 02:44:56 +08:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord coord(start_coord);
|
|
|
|
for (int idx = 0; idx < this->size(dim); ++idx) {
|
|
|
|
coord[dim] = idx;
|
|
|
|
if (dim < kRank - 1) {
|
|
|
|
elementwise_in_place(op, tensor, dim + 1, coord);
|
2018-05-17 02:44:56 +08:00
|
|
|
} else {
|
2018-09-19 07:58:03 +08:00
|
|
|
op(this->at(coord), tensor.at(coord));
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
|
|
|
|
template <typename Src>
|
|
|
|
struct LambdaBinaryAddition {
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_& a, Src b) const { a += Storage_(b); }
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
template <typename Src>
|
|
|
|
struct LambdaBinarySubtraction {
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_& a, Src b) const { a -= Storage_(b); }
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
template <typename Src>
|
|
|
|
struct LambdaBinaryMultiplication {
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_& a, Src b) const { a *= Storage_(b); }
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
template <typename Src>
|
|
|
|
struct LambdaBinaryDivision {
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_& a, Src b) const { a /= Storage_(b); }
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
/// Accumulate in place
|
2018-09-19 07:58:03 +08:00
|
|
|
template <typename SrcTensorView>
|
|
|
|
HostTensorView& operator+=(SrcTensorView const& tensor) {
|
|
|
|
LambdaBinaryAddition<typename SrcTensorView::Storage> op;
|
2018-05-17 02:44:56 +08:00
|
|
|
elementwise_in_place(op, tensor);
|
|
|
|
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Subtract in place
|
2018-09-19 07:58:03 +08:00
|
|
|
template <typename SrcTensorView>
|
|
|
|
HostTensorView& operator-=(SrcTensorView const& tensor) {
|
|
|
|
LambdaBinarySubtraction<typename SrcTensorView::Storage> op;
|
2018-05-17 02:44:56 +08:00
|
|
|
elementwise_in_place(op, tensor);
|
|
|
|
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Multiply in place
|
2018-09-19 07:58:03 +08:00
|
|
|
template <typename SrcTensorView>
|
|
|
|
HostTensorView& operator*=(SrcTensorView const& tensor) {
|
|
|
|
LambdaBinaryMultiplication<typename SrcTensorView::Storage> op;
|
2018-05-17 02:44:56 +08:00
|
|
|
elementwise_in_place(op, tensor);
|
|
|
|
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Divide in place
|
2018-09-19 07:58:03 +08:00
|
|
|
template <typename SrcTensorView>
|
|
|
|
HostTensorView& operator/=(SrcTensorView const& tensor) {
|
|
|
|
LambdaBinaryDivision<typename SrcTensorView::Storage> op;
|
2018-05-17 02:44:56 +08:00
|
|
|
elementwise_in_place(op, tensor);
|
|
|
|
|
|
|
|
return *this;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Comparison operator
|
|
|
|
struct EqualsOperator {
|
|
|
|
bool equal;
|
2018-09-19 07:58:03 +08:00
|
|
|
Storage_ eps;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
EqualsOperator(Storage_ _epsilon) : equal(true), eps(_epsilon) {}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_ a, Storage_ b) {
|
|
|
|
if (std::abs(Storage_(a - b)) > eps * std::max(std::abs(a), std::abs(b))) {
|
2018-05-17 02:44:56 +08:00
|
|
|
equal = false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// equality with epsilon tolerance
|
2018-09-19 07:58:03 +08:00
|
|
|
bool equals(Base const& tensor, Storage epsilon) const {
|
2018-05-17 02:44:56 +08:00
|
|
|
EqualsOperator comparison_op(epsilon);
|
|
|
|
bool equal_size = elementwise_in_place(comparison_op, tensor);
|
|
|
|
|
|
|
|
return equal_size && comparison_op.equal;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Compares two values which are smaller or equal to a long long int
|
|
|
|
struct BitEqualsOperator {
|
|
|
|
bool equal;
|
|
|
|
long long eps;
|
|
|
|
uint64_t index;
|
|
|
|
|
|
|
|
BitEqualsOperator(long long _ulps_threshold) : equal(true), eps(_ulps_threshold), index(0) {}
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_ a, Storage_ b) {
|
2018-05-17 02:44:56 +08:00
|
|
|
// convert bits to integers
|
|
|
|
long long bits_a = 0;
|
|
|
|
long long bits_b = 0;
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
*reinterpret_cast<Storage_*>(&bits_a) = TypeTraits<Storage_>::remove_negative_zero(a);
|
|
|
|
*reinterpret_cast<Storage_*>(&bits_b) = TypeTraits<Storage_>::remove_negative_zero(b);
|
2018-05-17 02:44:56 +08:00
|
|
|
|
|
|
|
// compute diff
|
|
|
|
long long ulps = bits_a - bits_b;
|
|
|
|
if (std::abs(ulps) > eps) {
|
|
|
|
equal = false;
|
|
|
|
}
|
|
|
|
index++;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// equality with ulps tolerance
|
2018-09-19 07:58:03 +08:00
|
|
|
bool bit_equals(Base const& tensor, long long ulps_threshold = 0) {
|
2018-05-17 02:44:56 +08:00
|
|
|
BitEqualsOperator comparison_op(ulps_threshold);
|
|
|
|
bool equal_size = elementwise_in_place(comparison_op, tensor);
|
|
|
|
|
|
|
|
return equal_size && comparison_op.equal;
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Fills with random data
|
|
|
|
template <typename Gen>
|
|
|
|
void fill_random(Gen generator) {
|
|
|
|
elementwise_stream(generator);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Procedurally assigns elements
|
|
|
|
template <typename Gen>
|
|
|
|
void generate(Gen generator) {
|
|
|
|
elementwise_generate(generator);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Procedurally visits elements
|
|
|
|
template <typename Gen>
|
|
|
|
void visit(Gen& generator) const {
|
|
|
|
elementwise_visit(generator);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Generator to fill a tensor with the identity matrix
|
|
|
|
struct LambdaFillIdentity {
|
2018-09-19 07:58:03 +08:00
|
|
|
Storage_ operator()(TensorCoord const& coord) {
|
|
|
|
return (coord.at(1) == coord.at(2) ? Storage_(1) : Storage_(0));
|
|
|
|
}
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
/// initializes with identity
|
|
|
|
void fill_identity() {
|
|
|
|
LambdaFillIdentity op;
|
|
|
|
elementwise_generate(op);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Lambda for fill_linear()
|
|
|
|
struct LambdaFillLinear {
|
2018-09-19 07:58:03 +08:00
|
|
|
TensorCoord v_;
|
|
|
|
Storage_ offset_;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
LambdaFillLinear(TensorCoord const& _v, Storage_ _offset) : v_(_v), offset_(_offset) {}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
Storage_ operator()(TensorCoord const& coord) {
|
|
|
|
return Storage_(v_.template dot<int>(coord)) + offset_;
|
|
|
|
}
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
/// computes elements as a linear combination of their coordinates
|
2018-09-19 07:58:03 +08:00
|
|
|
void fill_linear(TensorCoord v, Storage_ offset = Storage_(0)) {
|
2018-05-17 02:44:56 +08:00
|
|
|
LambdaFillLinear lambda(v, offset);
|
|
|
|
elementwise_generate(lambda);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// computes elements as a linear combination of their coordinates
|
2018-09-19 07:58:03 +08:00
|
|
|
void fill_sequential(Storage_ v = Storage_(1), Storage_ offset = Storage_(0)) {
|
|
|
|
int const count = this->size().count();
|
2018-05-17 02:44:56 +08:00
|
|
|
for (int i = 0; i < count; ++i) {
|
2018-09-19 07:58:03 +08:00
|
|
|
this->data()[i] = Storage_(i);
|
2018-05-17 02:44:56 +08:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Returns a constant value
|
|
|
|
struct LambdaFillValue {
|
2018-09-19 07:58:03 +08:00
|
|
|
Storage_ value;
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
LambdaFillValue(Storage_ _value) : value(_value) {}
|
2018-05-17 02:44:56 +08:00
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
Storage_ operator()() { return value; }
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
/// fills with a value
|
2018-09-19 07:58:03 +08:00
|
|
|
void fill(Storage_ val = Storage_(0)) {
|
2018-05-17 02:44:56 +08:00
|
|
|
LambdaFillValue op(val);
|
|
|
|
elementwise_stream(op);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Conversion from Src to T
|
|
|
|
template <typename Src>
|
|
|
|
struct LambdaAssign {
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage_& a, Src b) const { a = Storage_(b); }
|
2018-05-17 02:44:56 +08:00
|
|
|
};
|
|
|
|
|
|
|
|
/// copies from external data source and performs type conversion
|
2018-09-19 07:58:03 +08:00
|
|
|
template <
|
|
|
|
typename SrcType,
|
|
|
|
typename SrcMapFunc_,
|
|
|
|
int SrcStorageRank_,
|
|
|
|
typename SrcIndex_,
|
|
|
|
typename SrcLongIndex_
|
|
|
|
>
|
|
|
|
void fill(
|
|
|
|
TensorView<SrcType, kRank, SrcMapFunc_, SrcStorageRank_, SrcIndex_, SrcLongIndex_> const& tensor) {
|
|
|
|
|
|
|
|
LambdaAssign<SrcType> op;
|
2018-05-17 02:44:56 +08:00
|
|
|
elementwise_in_place(op, tensor);
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Computes a norm
|
|
|
|
struct LambdaNorm {
|
|
|
|
double sum;
|
|
|
|
|
|
|
|
LambdaNorm() : sum(0) {}
|
|
|
|
|
2018-09-19 07:58:03 +08:00
|
|
|
void operator()(Storage const& element) {
|
2018-05-17 02:44:56 +08:00
|
|
|
double value(element);
|
|
|
|
double conj(element); // TODO - conjugates for complex
|
|
|
|
|
|
|
|
sum += value * conj;
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
/// Computes the norm of the matrix in double-precision
|
|
|
|
double norm() const {
|
|
|
|
LambdaNorm op;
|
|
|
|
elementwise_in_place(op);
|
|
|
|
|
|
|
|
return std::sqrt(op.sum);
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
|
|
|
|
} // namespace cutlass
|
2018-09-19 07:58:03 +08:00
|
|
|
|