375 lines
11 KiB
C++
375 lines
11 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
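/*! \file
    \brief Defines TensorRefPlanarComplex, a reference to a planar-complex tensor: a pointer to the
      real part, a layout object mapping logical coordinates to linear offsets, and the offset
      from each real element to its imaginary counterpart.
*/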
|
|
#pragma once
|
|
|
|
#include <cstdint>
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/complex.h"
|
|
#include "cutlass/tensor_ref.h"
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace cutlass {
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Element_>
|
|
struct PlanarComplexReference {
|
|
|
|
//
|
|
// Type definitions
|
|
//
|
|
|
|
using Element = Element_;
|
|
using ComplexElement = complex<Element>;
|
|
|
|
//
|
|
// Data members
|
|
//
|
|
|
|
Element *real;
|
|
Element *imag;
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
PlanarComplexReference(
|
|
Element *real_ = nullptr,
|
|
Element *imag_ = nullptr
|
|
):
|
|
real(real_), imag(imag_) { }
|
|
|
|
/// Loads the complex element
|
|
CUTLASS_HOST_DEVICE
|
|
operator complex<Element>() const {
|
|
return complex<Element>{*real, *imag};
|
|
}
|
|
|
|
/// Stores a complex element to the location pointed to by the reference
|
|
CUTLASS_HOST_DEVICE
|
|
PlanarComplexReference &operator=(complex<Element> const &rhs) {
|
|
*real = rhs.real();
|
|
*imag = rhs.imag();
|
|
return *this;
|
|
}
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* \brief TensorRef is a template for objects pointing to the start of tensors of arbitrary rank
|
|
and layout within memory. A TensorRef combines a pointer and a Layout concept
|
|
|
|
*/
|
|
template <
|
|
/// Data type of element stored within tensor (concept: NumericType)
|
|
typename Element_,
|
|
/// Defines a mapping from logical coordinate to linear memory (concept: Layout)
|
|
typename Layout_
|
|
>
|
|
class TensorRefPlanarComplex {
|
|
public:
|
|
/// Data type of individual access
|
|
using Element = Element_;
|
|
|
|
/// Complex element type
|
|
using ComplexElement = complex<Element>;
|
|
|
|
/// Mapping function from logical coordinate to linear memory
|
|
using Layout = Layout_;
|
|
|
|
static_assert(sizeof_bits<Element>::value >= 8,
|
|
"Planar complex not suitable for subbyte elements at this time");
|
|
|
|
/// Reference type to an element
|
|
using Reference = PlanarComplexReference<Element>;
|
|
|
|
/// Logical rank of tensor index space
|
|
static int const kRank = Layout::kRank;
|
|
|
|
/// Index type
|
|
using Index = typename Layout::Index;
|
|
|
|
/// Long index used for pointer offsets
|
|
using LongIndex = typename Layout::LongIndex;
|
|
|
|
/// Coordinate in logical tensor space
|
|
using TensorCoord = typename Layout::TensorCoord;
|
|
|
|
/// Layout's stride vector
|
|
using Stride = typename Layout::Stride;
|
|
|
|
/// TensorRef to constant data
|
|
using ConstTensorRef = TensorRefPlanarComplex<
|
|
typename platform::remove_const<Element>::type const,
|
|
Layout>;
|
|
|
|
/// TensorRef to non-constant data
|
|
using NonConstTensorRef = TensorRefPlanarComplex<
|
|
typename platform::remove_const<Element>::type,
|
|
Layout>;
|
|
|
|
/// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a
|
|
/// scalar, but degenerate cases such as these are difficult to accommodate without
|
|
/// extensive C++ metaprogramming or support for zero-length arrays.
|
|
static_assert(kRank > 0, "Cannot define a zero-rank TensorRef");
|
|
|
|
private:
|
|
|
|
/// Pointer
|
|
Element* ptr_;
|
|
|
|
/// Layout object maps logical coordinates to linear offsets
|
|
Layout layout_;
|
|
|
|
/// Offset to imaginary part
|
|
LongIndex imaginary_stride_;
|
|
|
|
public:
|
|
|
|
//
|
|
// Methods
|
|
//
|
|
|
|
/// Constructs a TensorRef with a pointer and layout object.
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex(
|
|
Element *ptr = nullptr, ///< pointer to start of tensor
|
|
Layout const &layout = Layout(), ///< layout object containing stride and mapping function
|
|
LongIndex imaginary_stride = 0
|
|
):
|
|
ptr_(ptr), layout_(layout), imaginary_stride_(imaginary_stride) {
|
|
|
|
}
|
|
|
|
/// Converting constructor from TensorRef to non-constant data.
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex(
|
|
NonConstTensorRef const &ref ///< TensorRef to non-const data
|
|
):
|
|
ptr_(ref.data()), layout_(ref.layout()), imaginary_stride_(ref.imaginary_stride_) { }
|
|
|
|
/// Returns a reference to constant-valued tensor.
|
|
CUTLASS_HOST_DEVICE
|
|
ConstTensorRef const_ref() const {
|
|
return ConstTensorRef(ptr_, layout_, imaginary_stride_);
|
|
}
|
|
|
|
CUTLASS_HOST_DEVICE
|
|
NonConstTensorRef non_const_ref() const {
|
|
return NonConstTensorRef(
|
|
const_cast<typename platform::remove_const<Element>::type *>(ptr_),
|
|
layout_,
|
|
imaginary_stride_);
|
|
}
|
|
|
|
/// Updates only the pointer
|
|
CUTLASS_HOST_DEVICE
|
|
void reset(Element* ptr = nullptr, LongIndex imaginary_stride = 0) {
|
|
ptr_ = ptr;
|
|
imaginary_stride_ = imaginary_stride;
|
|
}
|
|
|
|
/// Updates the pointer and layout object
|
|
CUTLASS_HOST_DEVICE
|
|
void reset(Element* ptr, Layout const &layout, LongIndex imaginary_stride) {
|
|
ptr_ = ptr;
|
|
layout_ = layout;
|
|
imaginary_stride_ = imaginary_stride;
|
|
}
|
|
|
|
/// Returns true if the TensorRef is non-null
|
|
CUTLASS_HOST_DEVICE
|
|
bool good() const {
|
|
return ptr_ != nullptr;
|
|
}
|
|
|
|
/// Returns the pointer to referenced data
|
|
CUTLASS_HOST_DEVICE
|
|
Element * data() const { return ptr_; }
|
|
|
|
/// Returns the pointer to referenced data
|
|
CUTLASS_HOST_DEVICE
|
|
Element * imaginary_data() const { return ptr_ + imaginary_stride_; }
|
|
|
|
/// Returns a reference to the element at a given linear index
|
|
CUTLASS_HOST_DEVICE
|
|
Reference data(LongIndex idx) const {
|
|
return Reference(ptr_ + idx, ptr_ + idx + imaginary_stride_);
|
|
}
|
|
|
|
/// Returns the layout object
|
|
CUTLASS_HOST_DEVICE
|
|
Layout & layout() {
|
|
return layout_;
|
|
}
|
|
|
|
/// Returns the layout object
|
|
CUTLASS_HOST_DEVICE
|
|
Layout layout() const {
|
|
return layout_;
|
|
}
|
|
|
|
/// Gets the stride to an imaginary element
|
|
LongIndex imaginary_stride() const {
|
|
return imaginary_stride_;
|
|
}
|
|
|
|
/// Gets the stride to an imaginary element
|
|
LongIndex &imaginary_stride() {
|
|
return imaginary_stride_;
|
|
}
|
|
|
|
/// Returns the layout object's stride vector
|
|
CUTLASS_HOST_DEVICE
|
|
Stride stride() const {
|
|
return layout_.stride();
|
|
}
|
|
|
|
/// Returns the layout object's stride vector
|
|
CUTLASS_HOST_DEVICE
|
|
Stride & stride() {
|
|
return layout_.stride();
|
|
}
|
|
|
|
/// Returns the layout object's stride in a given physical dimension
|
|
CUTLASS_HOST_DEVICE
|
|
Index stride(int dim) const {
|
|
return layout_.stride().at(dim);
|
|
}
|
|
|
|
/// Returns the layout object's stride in a given physical dimension
|
|
CUTLASS_HOST_DEVICE
|
|
Index & stride(int dim) {
|
|
return layout_.stride().at(dim);
|
|
}
|
|
|
|
/// Computes the offset of an index from the origin of the tensor
|
|
CUTLASS_HOST_DEVICE
|
|
LongIndex offset(TensorCoord const& coord) const {
|
|
return layout_(coord);
|
|
}
|
|
|
|
/// Returns a reference to the element at a given Coord
|
|
CUTLASS_HOST_DEVICE
|
|
Reference at(TensorCoord const& coord) const {
|
|
return data(offset(coord));
|
|
}
|
|
|
|
/// Returns a reference to the element at a given Coord
|
|
CUTLASS_HOST_DEVICE
|
|
Reference operator[](TensorCoord const& coord) const {
|
|
return data(offset(coord));
|
|
}
|
|
|
|
/// Adds an offset to each pointer
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex & add_pointer_offset(LongIndex offset_) {
|
|
ptr_ += offset_;
|
|
return *this;
|
|
}
|
|
|
|
/// Adds an offset to each pointer
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex & add_coord_offset(TensorCoord const &coord) {
|
|
add_pointer_offset(offset(coord));
|
|
return *this;
|
|
}
|
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex operator+(TensorCoord const& b) const {
|
|
TensorRefPlanarComplex result(*this);
|
|
result.add_coord_offset(b);
|
|
return result;
|
|
}
|
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex & operator+=(TensorCoord const& b) {
|
|
add_coord_offset(b);
|
|
return *this;
|
|
}
|
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex operator-(TensorCoord const& b) const {
|
|
TensorRefPlanarComplex result(*this);
|
|
result.add_pointer_offset(-offset(b));
|
|
return result;
|
|
}
|
|
|
|
/// Returns a TensorRef offset by a given amount
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex & operator-=(TensorCoord const& b) {
|
|
add_pointer_offset(-offset(b));
|
|
return *this;
|
|
}
|
|
|
|
/// TensorRef to real-valued tensor
|
|
CUTLASS_HOST_DEVICE
|
|
cutlass::TensorRef<Element, Layout> ref_real() const {
|
|
return cutlass::TensorRef<Element, Layout>(data(), layout());
|
|
}
|
|
|
|
/// TensorRef to real-valued tensor
|
|
CUTLASS_HOST_DEVICE
|
|
cutlass::TensorRef<Element, Layout> ref_imag() const {
|
|
return cutlass::TensorRef<Element, Layout>(imaginary_data(), layout());
|
|
}
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/// Constructs a TensorRef, deducing types from arguments.
|
|
template <
|
|
typename Element,
|
|
typename Layout
|
|
>
|
|
CUTLASS_HOST_DEVICE
|
|
TensorRefPlanarComplex<Element, Layout> make_TensorRefPlanarComplex(
|
|
Element *ptr,
|
|
Layout const &layout,
|
|
int64_t imaginary_stride) {
|
|
|
|
return TensorRefPlanarComplex<Element, Layout>(ptr, layout, imaginary_stride);
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace cutlass
|
|
|
|
///////////////////////////////////////////////////////////////////////////////////////////////////
|