Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_view.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <cmath>
32 
33 #include <cutlass/cutlass.h>
34 #include <cutlass/tensor_ref.h>
35 
36 namespace cutlass {
37 
39 
41 template <typename T>
42 class TensorView : public TensorRef<T, 4> {
43  public:
46 
48  typedef Base TensorRef_t;
49 
52 
54  static int const Rank = TensorRef_t::Rank;
55 
57  typedef int Offset_t;
58 
61 
62  private:
63  //
64  // Data members
65  //
66 
68  TensorRef_t ref_;
69 
71  Coord_t size_;
72 
73  public:
74  //
75  // Device and Host Methods
76  //
77 
81 
84  TensorView(TensorRef_t const& _ref, Coord_t const& _size) : Base(_ref), size_(_size) {}
85 
88  bool good() const { return ref().good(); }
89 
92  T* data() const { return ref().data(); }
93 
96  void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) {
97  Base::operator=(_ref);
98  size_ = _size;
99  }
100 
103  TensorRef_t& ref() { return *this; }
104 
108 
111  TensorRef_t const& ref() const { return *this; }
112 
115  Coord_t const& size() const { return size_; }
116 
119  int size(int dim) const { return size_.at(dim); }
120 
123  Coord_t const& stride() const { return ref().stride(); }
124 
127  int const& stride(int dim) const { return ref().stride(dim); }
128 
131  TensorView& operator=(TensorView const& _tensor) {
132  Base::operator=(_tensor._ref);
133  size_ = _tensor.size_;
134  return *this;
135  }
136 
139  Offset_t offset(Coord_t const& coord) const { return ref().offset(coord); }
140 
143  bool contains(Coord_t const& coord) const {
144  for (int dim = 0; dim < Rank; ++dim) {
145  if (coord.at(dim) >= size_.at(dim)) {
146  return false;
147  }
148  }
149  return true;
150  }
151 
154  T& at(Coord_t const& coord) const { return ref().at(coord); }
155 
157  T& operator[](Coord<Rank> const& coord) const { return at(coord); }
158 
161  T& at(Offset_t idx) const { return ref().at(idx); }
162 
165  TensorView<T> subview(Coord_t const& location, Coord_t size) const {
166  return TensorView<T>(ref() + location, size.clamp(size_ - location));
167  }
168 };
169 
171 
172 } // namespace cutlass
CUTLASS_HOST_DEVICE TensorRef_t const & ref() const
Accesses the tensor reference pointing to data.
Definition: tensor_view.h:111
Definition: convert.h:33
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE Storage & at(Coord< Rank > const &coord) const
Returns a reference to the element at a given Coord.
Definition: tensor_ref.h:121
int Offset_t
Type used to compute the offset of an element to the base of a tensor.
Definition: tensor_view.h:57
static int const Rank
Rank of tensor.
Definition: tensor_ref.h:48
CUTLASS_HOST_DEVICE TensorView()
Default constructor.
Definition: tensor_view.h:80
CUTLASS_HOST_DEVICE int size(int dim) const
Accesses the size.
Definition: tensor_view.h:119
CUTLASS_HOST_DEVICE Coord & clamp(Coord< N > const &max, Coord< N > const &min=Coord< N >())
Clamps a coordinate to a range specified by maximum and minimum values.
Definition: coord.h:219
Coord< Rank > Coord_t
Coordinate into tensor.
Definition: tensor_view.h:60
CUTLASS_HOST_DEVICE void reset(TensorRef_t const &_ref=TensorRef_t(0), Coord_t const &_size=Coord_t())
Updates the reference and size of a TensorView object.
Definition: tensor_view.h:96
CUTLASS_HOST_DEVICE bool contains(Coord_t const &coord) const
Determines whether a location is within a tensor.
Definition: tensor_view.h:143
CUTLASS_HOST_DEVICE int const & stride(int dim) const
Accesses the stride.
Definition: tensor_view.h:127
static int const Rank
Rank of tensor.
Definition: tensor_view.h:54
CUTLASS_HOST_DEVICE T & at(Offset_t idx) const
Element-wise accessor.
Definition: tensor_view.h:161
CUTLASS_HOST_DEVICE ConstTensorRef_t const_ref()
Returns a ConstTensorRef_t, i.e. a tensor reference whose element type is const-qualified.
Definition: tensor_view.h:107
CUTLASS_HOST_DEVICE Storage * data() const
Returns the pointer to referenced data.
Definition: tensor_ref.h:99
Structure extending a TensorRef (pointer and stride) with the size of each dimension of the tensor.
Definition: tensor_view.h:42
CUTLASS_HOST_DEVICE long long offset(Coord< Rank > const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:115
Structure modeling a pointer and stride into a tensor.
Definition: tensor_ref.h:42
TensorRef< T, 4 > Base
Reference and stride.
Definition: tensor_view.h:45
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE bool good() const
Returns true if the TensorView is bound to some memory.
Definition: tensor_view.h:88
CUTLASS_HOST_DEVICE bool good() const
Returns true if the TensorRef may be safely accessed.
Definition: tensor_ref.h:95
CUTLASS_HOST_DEVICE Offset_t offset(Coord_t const &coord) const
Returns the index of an element.
Definition: tensor_view.h:139
CUTLASS_HOST_DEVICE T * data() const
Returns a pointer to data.
Definition: tensor_view.h:92
T & operator[](Coord< Rank > const &coord) const
Element-wise accessor.
Definition: tensor_view.h:157
Base TensorRef_t
Reference and stride.
Definition: tensor_view.h:48
CUTLASS_HOST_DEVICE int & at()
Gets the index of a given Coord element.
Definition: coord.h:185
CUTLASS_HOST_DEVICE T & at(Coord_t const &coord) const
Element-wise accessor.
Definition: tensor_view.h:154
CUTLASS_HOST_DEVICE Coord_t const & size() const
Accesses the size.
Definition: tensor_view.h:115
CUTLASS_HOST_DEVICE Coord_t const & stride() const
Accesses the stride.
Definition: tensor_view.h:123
CUTLASS_HOST_DEVICE TensorRef_t & ref()
Accesses the tensor reference pointing to data.
Definition: tensor_view.h:103
CUTLASS_HOST_DEVICE Coord< Rank > const & stride() const
Returns the stride of the tensor.
Definition: tensor_ref.h:103
CUTLASS_HOST_DEVICE TensorView & operator=(TensorView const &_tensor)
Assigns the TensorView.
Definition: tensor_view.h:131
Basic include for CUTLASS macros.
CUTLASS_HOST_DEVICE TensorView(TensorRef_t const &_ref, Coord_t const &_size)
Constructs a TensorView from a TensorRef and size.
Definition: tensor_view.h:84
TensorRef< T const, 4 > ConstTensorRef_t
Reference to constant type.
Definition: tensor_view.h:51
CUTLASS_HOST_DEVICE TensorView< T > subview(Coord_t const &location, Coord_t size) const
Returns a TensorView given location and size quantities.
Definition: tensor_view.h:165