Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_ref.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <typeinfo>
31 
32 #include <cutlass/coord.h>
33 #include <cutlass/cutlass.h>
34 #include <cutlass/vector.h>
35 
36 namespace cutlass {
37 
39 
/// Structure modeling a pointer and stride into a tensor.
///
/// A TensorRef pairs a raw pointer with a per-dimension stride vector. A Coord is mapped to a
/// linear element offset by the dot product of the stride and the coordinate (see offset()).
template <typename Storage_, int Rank_>
class TensorRef {
 public:
  /// Data type of individual access.
  typedef Storage_ Storage;

  /// Rank of tensor.
  static int const Rank = Rank_;

 private:
  //
  // Data members
  //

  /// Pointer to the element at the origin of the tensor.
  Storage* ptr_;

  /// Stride of each dimension, in units of Storage elements (ptr_ is indexed by offset()).
  Coord<Rank> stride_;

 public:
  //
  // Methods
  //

  /// Default ctor: constructs a reference with a null pointer, so good() returns false.
  CUTLASS_HOST_DEVICE
  TensorRef() : ptr_(nullptr) {}

  /// Constructs from a pointer and a stride.
  CUTLASS_HOST_DEVICE
  TensorRef(Storage* ptr, Coord<Rank> stride) : ptr_(ptr), stride_(stride) {}

  /// Updates the pointer and stride. Called with no arguments, it sets a null pointer and a
  /// stride built from Coord<Rank>(0).
  CUTLASS_HOST_DEVICE
  void reset(Storage* ptr = nullptr, Coord<Rank> stride = Coord<Rank>(0)) {
    ptr_ = ptr;
    stride_ = stride;
  }

  /// Reinterprets the referenced data as elements of type T.
  ///
  /// Every stride except the last is rescaled by Extent<Storage>::kValue / Extent<T>::kValue, so
  /// that it is expressed in units of T. The last stride is copied unchanged.
  /// The rescale uses integer arithmetic: callers should ensure that
  /// stride * Extent<Storage>::kValue is divisible by Extent<T>::kValue.
  /// Marked const because it does not modify *this. Marked host+device so that it is usable in
  /// kernels.
  template <typename T>
  CUTLASS_HOST_DEVICE TensorRef<T, Rank> convert() const {
    Coord<Rank> converted_stride;
    for (int i = 0; i < Rank - 1; ++i) {
      converted_stride[i] = stride_[i] * Extent<Storage>::kValue / Extent<T>::kValue;
    }
    converted_stride[Rank - 1] = stride_[Rank - 1];

    return TensorRef<T, Rank>(reinterpret_cast<T*>(ptr_), converted_stride);
  }

  /// Returns true if the TensorRef may be safely accessed (i.e. its pointer is non-null).
  CUTLASS_HOST_DEVICE
  bool good() const { return ptr_ != nullptr; }

  /// Returns the pointer to referenced data.
  CUTLASS_HOST_DEVICE
  Storage* data() const { return ptr_; }

  /// Returns the stride of the tensor.
  CUTLASS_HOST_DEVICE
  Coord<Rank> const& stride() const { return stride_; }

  /// Returns the stride of the tensor in the given dimension.
  CUTLASS_HOST_DEVICE
  int const& stride(int dim) const { return stride_.at(dim); }

  /// Returns the larger of stride[1] and stride[2] as the 'leading dimension'.
  /// Precondition: Rank >= 3, since strides 1 and 2 are read unconditionally.
  CUTLASS_HOST_DEVICE
  int leading_dim() const { return __NV_STD_MAX(stride_[1], stride_[2]); }

  /// Computes the offset of a coordinate from the origin of the tensor.
  /// The offset is dot(stride, coord), in Storage elements, accumulated as long long.
  CUTLASS_HOST_DEVICE
  long long offset(Coord<Rank> const& coord) const {
    return stride_.template dot<long long>(coord);
  }

  /// Returns a reference to the element at a given Coord.
  CUTLASS_HOST_DEVICE
  Storage& at(Coord<Rank> const& coord) const { return ptr_[offset(coord)]; }

  /// Element-wise accessor; equivalent to at(coord). Host+device so it is usable in kernels.
  CUTLASS_HOST_DEVICE
  Storage& operator[](Coord<Rank> const& coord) const { return at(coord); }

  /// Returns a reference to the element at a linear offset of idx elements from data().
  CUTLASS_HOST_DEVICE
  Storage& at(int idx) const { return ptr_[idx]; }

  /// Element-wise accessor; equivalent to at(idx). Host+device so it is usable in kernels.
  CUTLASS_HOST_DEVICE
  Storage& operator[](int idx) const { return at(idx); }

  /// Moves the pointer forward by offset(b) elements and returns *this.
  CUTLASS_HOST_DEVICE
  TensorRef& advance(Coord<Rank> const& b) {
    ptr_ += offset(b);
    return *this;
  }

  /// Returns a TensorRef with the same stride whose pointer is moved forward by offset(b).
  CUTLASS_HOST_DEVICE
  TensorRef operator+(Coord<Rank> const& b) const { return TensorRef(ptr_ + offset(b), stride_); }

  /// Returns a TensorRef with the same stride whose pointer is moved backward by offset(b).
  CUTLASS_HOST_DEVICE
  TensorRef operator-(Coord<Rank> const& b) const { return TensorRef(ptr_ - offset(b), stride_); }
};
148 
150 
151 } // namespace cutlass
CUTLASS_HOST_DEVICE int const & stride(int dim) const
Returns the stride of the tensor in the given dimension.
Definition: tensor_ref.h:107
Storage & operator[](int idx) const
Element-wise accessor.
Definition: tensor_ref.h:131
Definition: convert.h:33
CUTLASS_HOST_DEVICE Storage & at(Coord< Rank > const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:121
CUTLASS_HOST_DEVICE TensorRef & advance(Coord< Rank > const &b)
Adds an offset to the pointer.
Definition: tensor_ref.h:135
static int const Rank
Rank of tensor.
Definition: tensor_ref.h:48
CUTLASS_HOST_DEVICE TensorRef operator+(Coord< Rank > const &b) const
Returns a TensorRef with the same stride whose pointer is moved forward by the offset of a given Coord.
Definition: tensor_ref.h:142
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
Storage_ Storage
Data type of individual access.
Definition: tensor_ref.h:45
CUTLASS_HOST_DEVICE TensorRef operator-(Coord< Rank > const &b) const
Returns a TensorRef with the same stride whose pointer is moved backward by the offset of a given Coord.
Definition: tensor_ref.h:146
#define __NV_STD_MAX(a, b)
Select maximum(a, b)
Definition: platform.h:155
CUTLASS_HOST_DEVICE int leading_dim() const
Returns the larger of stride[1] and stride[2] as the 'leading dimension'.
Definition: tensor_ref.h:111
CUTLASS_HOST_DEVICE Storage * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:99
CUTLASS_HOST_DEVICE TensorRef(Storage *ptr, Coord< Rank > stride)
Constructs from a pointer and a stride.
Definition: tensor_ref.h:72
Storage & operator[](Coord< Rank > const &coord) const
Element-wise accessor.
Definition: tensor_ref.h:124
#define nullptr
nullptr
Definition: platform.h:136
CUTLASS_HOST_DEVICE long long offset(Coord< Rank > const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:115
Structure modeling a pointer and stride into a tensor.
Definition: tensor_ref.h:42
TensorRef< T, Rank > convert()
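Conversion function: reinterprets the referenced data as elements of type T, rescaling all but the last stride by the ratio of the element extents.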
Definition: tensor_ref.h:83
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE bool good() const
Returns true if the TensorRef may be safely accessed.
Definition: tensor_ref.h:95
Defines a 1D vector of elements held in the registers of each thread.
CUTLASS_HOST_DEVICE void reset(Storage *ptr=nullptr, Coord< Rank > stride=Coord< Rank >(0))
Updates the pointer and stride of a TensorRef.
Definition: tensor_ref.h:76
CUTLASS_HOST_DEVICE int & at()
Gets the index of a given Coord element.
Definition: coord.h:185
CUTLASS_HOST_DEVICE Coord< Rank > const & stride() const
Returns the stride of the tensor.
Definition: tensor_ref.h:103
Basic include for CUTLASS macros.
CUTLASS_HOST_DEVICE Storage & at(int idx) const
Returns a reference to the element at a given linear offset from the data pointer.
Definition: tensor_ref.h:128
CUTLASS_HOST_DEVICE TensorRef()
Default ctor.
Definition: tensor_ref.h:68
Returns the extent of a scalar or vector.
Definition: vector.h:161