Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
tensor_ref_collection.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/tensor_ref.h"
32 
33 namespace cutlass {
34 
36 //
37 // TensorRefCollection is a concept for storing a logical collection of TensorRef objects. Classes
38 // satisfying the TensorRefCollection concept must support the following:
39 //
40 // // Define storage type
41 // typedef typename TensorRefCollection::Storage Storage;
42 //
43 // // Define a type for offsets in memory
44 // typedef typename TensorRefCollection::LongIndex LongIndex;
45 //
46 // // Define a ConstIterator type satisfying TensorRefIterator
47 // typedef typename TensorRefCollection::ConstIterator TensorRefIterator;
48 //
49 // // Implement a begin() method.
50 // TensorRefIterator iterator = collection.begin();
51 //
52 //
53 // TensorRefIterator is a concept for accessing an element in a TensorRefCollection. Classes
54 // satisfying the TensorRefIterator concept must support the following:
55 //
56 // // Define a TensorRef type accessed by the iterator
57 // typedef typename TensorRefIterator::TensorRef TensorRef;
58 //
59 // // Access the TensorRef
60 // TensorRef ref = *iterator;
61 //
62 // // Pre-increment and post-increment
63 // ++iterator;
64 // iterator++;
65 //
66 // // Pre-decrement and post-decrement
67 // --iterator;
68 // iterator--;
69 //
71 
74 template <
76  typename Storage_,
78  int Rank_,
80  typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
82  int StorageRank_ = MapFunc_::kStorageRank,
84  typename Index_ = int,
86  typename LongIndex_ = long long
87 >
89  public TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> {
90 
91  //
92  // Type definitions
93  //
94 
97 
99  typedef typename Base::Storage Storage;
100 
102  typedef Index_ Index;
103 
105  typedef LongIndex_ LongIndex;
106 
109 
111  typedef Base TensorRef;
112 
115  public:
117  typedef Base TensorRef;
118 
119  private:
120 
122  TensorRefBatchStrided const &ref_;
123 
125  LongIndex offset_;
126 
127  public:
128 
132  TensorRefBatchStrided const &ref,
133  LongIndex offset = 0): ref_(ref), offset_(offset) { }
134 
135  /// Dereferences the iterator: returns, by value, the parent reference advanced by offset_.
136  CUTLASS_HOST_DEVICE
137  TensorRef operator*() const {  // was `TensorRef *operator() const`, which has no parameter list
138  TensorRef ref(ref_);  // copy of the parent collection's underlying TensorRef
139  ref.add_pointer_offset(offset_);
140  return ref;
141  }
142 
146  offset_ += ref_.tensor_stride;
147  return *this;
148  }
149 
153  ConstIterator ret(*this);
154  offset_ += ref_.tensor_stride;
155  return ret;
156  }
157 
158  /// Returns an iterator advanced by (idx) tensors; this iterator is left unchanged.
159  CUTLASS_HOST_DEVICE
160  ConstIterator operator+(Index idx) {
161  return ConstIterator(ref_, offset_ + ref_.tensor_stride * idx);  // was `ref`, which is not declared here
162  }
163 
167  offset_ += ref_.tensor_stride * idx;
168  return *this;
169  }
170 
174  offset_ -= ref_.tensor_stride;
175  return *this;
176  }
177 
181  ConstIterator ret(*this);
182  offset_ -= ref_.tensor_stride;
183  return ret;
184  }
185 
189  return ConstIterator(ref_, offset_ - ref_.tensor_stride * idx);
190  }
191 
195  offset_ -= ref_.tensor_stride * idx;
196  return *this;
197  }
198 
199  /// Returns the difference in pointer offset between this iterator and (it).
200  CUTLASS_HOST_DEVICE
201  LongIndex operator-(ConstIterator const &it) const {  // offsets are LongIndex; no `Stride` type is declared in this file
202  return offset_ - it.offset_;
203  }
204  };
205 
206  //
207  // Data members
208  //
209 
212 
213  //
214  // Methods
215  //
216 
217  // Default ctor
220 
221  // Constructs from a tensor reference and the stride between tensors (stride defaults to 0)
223  TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0):
224  TensorRef(ref),
225  tensor_stride(_tensor_stride) { }
226 
230  return idx * tensor_stride;
231  }
232 
233  // Returns the TensorRef for tensor (idx): a copy of this reference advanced by get_pointer_offset(idx)
235  TensorRef at(Index idx) const {
236  TensorRef ref(*this);
237  ref.add_pointer_offset(get_pointer_offset(idx));
238  return ref;
239  }
240 
244  return ConstIterator(*this);
245  }
246 };
247 
249 
258 template <
260  typename Storage_,
262  int Rank_,
264  typename MapFunc_ = IdentityTensorMapFunc<Rank_>,
266  int StorageRank_ = MapFunc_::kStorageRank,
268  typename Index_ = int,
270  typename LongIndex_ = long long
271 >
273  //
274  // Type definitions
275  //
276 
279 
281  typedef Storage_ Storage;
282 
284  typedef Index_ Index;
285 
287  typedef LongIndex_ LongIndex;
288 
291 
294  public:
295 
297  typedef Base TensorRef;
298 
299  private:
301  TensorRefArray const &ref_;
302 
304  int idx_;
305 
306  public:
307 
308  /// Constructs a ConstIterator over the TensorRef objects of (ref), starting at position (idx).
309  CUTLASS_HOST_DEVICE
310  ConstIterator(TensorRefArray const &ref, int idx = 0): ref_(ref), idx_(idx) { }  // parameter type now matches ref_
311 
312  /// Dereferences the iterator: returns, by value, the TensorRef at the current index.
313  CUTLASS_HOST_DEVICE
314  TensorRef operator*() const {  // was `TensorRef *operator() const`, which has no parameter list
315  return ref_.at(idx_);  // the collection exposes at(); no reference() member is defined
316  }
317 
321  ++idx_;
322  return *this;
323  }
324 
328  ConstIterator ret(*this);
329  idx_ ++;
330  return ret;
331  }
332 
335  return ConstIterator(ref_, idx_ + idx);
336  }
337 
340  idx_ += idx;
341  return *this;
342  }
343 
346  --idx_;
347  return *this;
348  }
349 
353  ConstIterator ret(*this);
354  --idx_;
355  return ret;
356  }
357 
360  idx_ -= idx;
361  return *this;
362  }
363 
364  CUTLASS_HOST_DEVICE
365  ConstIterator operator-(Index idx) {
366  return ConstIterator(ref_, idx_ - idx);  // returns an iterator moved back by (idx); previously added idx like operator+
367  }
368  };
369 
370  //
371  // Data members
372  //
373 
376 
379 
380  //
381  // Methods
382  //
383 
384  // Default ctor
387 
388  // Constructs from an array of base pointers and an array of pointers to stride arrays
389  CUTLASS_HOST_DEVICE
390  TensorRefArray(  // NOTE(review): listed as TensorArrayRef in the member index; renamed to match the type of ConstIterator::ref_ - confirm against the struct declaration
391  Storage **_pointers,
392  Index *_strides[kStorageRank - 1]): pointers(_pointers) {  // elements must be pointers to be assignable to strides[i]
393
394  // Copy the pointers to the stride arrays; the stride values themselves are not copied
395  for (int i = 0; i < kStorageRank - 1; ++i) {
396  strides[i] = _strides[i];
397  }
398  }
399 
400  // Returns the TensorRef at position (idx): base pointer pointers[idx] with one stride read from each stride array
401  CUTLASS_HOST_DEVICE
402  TensorRef at(Index idx) const {
403  Coord<kStorageRank - 1, Index> stride;
404  CUTLASS_PRAGMA_UNROLL
405  for (int i = 0; i < kStorageRank - 1; ++i) {
406  stride[i] = strides[i][idx];  // strides holds kStorageRank - 1 pointers, so the stride dimension is the outer index
407  }
408  return TensorRef(pointers[idx], stride);
409  }
410 
414  return ConstIterator(*this);
415  }
416 };
417 
419 
420 } // namespace cutlass
Constant iterator over tensors implied by TensorRefBatchStrided.
Definition: tensor_ref_collection.h:114
Definition: convert.h:33
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Moves to the previous tensor.
Definition: tensor_ref_collection.h:173
Defines a structure containing strides, bounds, and a pointer to tensor data.
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Definition: tensor_ref_collection.h:345
Index * strides[kStorageRank - 1]
Array of strides.
Definition: tensor_ref_collection.h:378
Storage_ Storage
Element pointed to by the TensorRef.
Definition: tensor_ref_collection.h:281
Definition: tensor_ref_collection.h:272
static int const kStorageRank
Rank of the stride vector.
Definition: tensor_ref_collection.h:290
Base::Storage Storage
Storage type.
Definition: tensor_ref_collection.h:99
CUTLASS_HOST_DEVICE ConstIterator & operator-=(Index idx)
Definition: tensor_ref_collection.h:359
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref_collection.h:287
CUTLASS_HOST_DEVICE ConstIterator begin()
Returns an iterator.
Definition: tensor_ref_collection.h:243
LongIndex tensor_stride
Stride between tensors.
Definition: tensor_ref_collection.h:211
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Advances the iterator to point to the next tensor.
Definition: tensor_ref_collection.h:152
Base TensorRef
Tensor reference implied by the TensorRefBatchStrided.
Definition: tensor_ref_collection.h:111
CUTLASS_HOST_DEVICE TensorRef * operator() const
Obtains a TensorRef pointed to by the iterator.
Definition: tensor_ref_collection.h:137
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Moves to the previous tensor.
Definition: tensor_ref_collection.h:180
CUTLASS_HOST_DEVICE ConstIterator begin()
Returns a TensorRefIterator over the TensorRef objects in this collection.
Definition: tensor_ref_collection.h:413
TensorRefIterator over TensorRef objects in TensorRefArray.
Definition: tensor_ref_collection.h:293
Index_ Index
Index type.
Definition: tensor_ref.h:146
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:62
CUTLASS_HOST_DEVICE ConstIterator(TensorArrayRef const &ref, int idx=0)
Constructs a ConstIterator over the TensorRef objects.
Definition: tensor_ref_collection.h:310
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > TensorRef
TensorRef type obtained from the TensorRefArray.
Definition: tensor_ref_collection.h:278
Index_ Index
Index type.
Definition: tensor_ref_collection.h:284
CUTLASS_HOST_DEVICE ConstIterator operator-(Index idx)
Definition: tensor_ref_collection.h:365
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref.h:149
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Moves to the previous TensorRef.
Definition: tensor_ref_collection.h:352
CUTLASS_HOST_DEVICE Stride operator-(ConstIterator const &it)
Returns the difference in offset between two iterators.
Definition: tensor_ref_collection.h:201
CUTLASS_HOST_DEVICE TensorRef * operator() const
Obtains a TensorRef pointed to by this iterator.
Definition: tensor_ref_collection.h:314
CUTLASS_HOST_DEVICE TensorRef at(Index idx) const
Definition: tensor_ref_collection.h:402
LongIndex_ LongIndex
Typically, strides in memory can be very large.
Definition: tensor_ref_collection.h:105
static int const kStorageRank
Rank of internal storage.
Definition: tensor_ref.h:143
Coord< kRank > TensorCoord
Coordinate in logical tensor space.
Definition: tensor_ref_collection.h:108
CUTLASS_HOST_DEVICE TensorRef at(Index idx) const
Definition: tensor_ref_collection.h:235
CUTLASS_HOST_DEVICE ConstIterator operator-(Index idx)
Returns an iterator moved backward by (idx) amount.
Definition: tensor_ref_collection.h:188
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
CUTLASS_HOST_DEVICE TensorRefBatchStrided()
Definition: tensor_ref_collection.h:219
CUTLASS_HOST_DEVICE LongIndex offset(TensorCoord const &coord) const
Computes the offset of an index from the origin of the tensor.
Definition: tensor_ref.h:331
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
Index_ Index
Index type.
Definition: tensor_ref_collection.h:102
CUTLASS_HOST_DEVICE ConstIterator operator+(Index idx)
Returns an iterator advanced by (idx) amount.
Definition: tensor_ref_collection.h:160
CUTLASS_HOST_DEVICE ConstIterator & operator+=(Index idx)
Advances this iterator by (idx) and returns a reference to self.
Definition: tensor_ref_collection.h:166
Storage ** pointers
Base addresses.
Definition: tensor_ref_collection.h:375
CUTLASS_HOST_DEVICE LongIndex get_pointer_offset(Index idx) const
Gets the pointer offset.
Definition: tensor_ref_collection.h:229
Definition: tensor_ref_collection.h:88
CUTLASS_HOST_DEVICE ConstIterator operator+(Index idx)
Definition: tensor_ref_collection.h:334
CUTLASS_HOST_DEVICE ConstIterator & operator-=(Index idx)
Moves this iterator by (idx) and returns a reference to self.
Definition: tensor_ref_collection.h:194
Base TensorRef
TensorRef returned by the iterator.
Definition: tensor_ref_collection.h:117
CUTLASS_HOST_DEVICE TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride=0)
Definition: tensor_ref_collection.h:223
CUTLASS_HOST_DEVICE ConstIterator & operator+=(Index idx)
Definition: tensor_ref_collection.h:339
CUTLASS_HOST_DEVICE ConstIterator(TensorRefBatchStrided const &ref, LongIndex offset=0)
Constructs a ConstIterator from a parent TensorRefBatchStrided.
Definition: tensor_ref_collection.h:131
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Advances the iterator to point to the next tensor.
Definition: tensor_ref_collection.h:145
TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Base
Underlying TensorRef type.
Definition: tensor_ref_collection.h:96
CUTLASS_HOST_DEVICE TensorArrayRef()
Definition: tensor_ref_collection.h:386
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Advances to next TensorRef.
Definition: tensor_ref_collection.h:320
CUTLASS_HOST_DEVICE TensorArrayRef(Storage **_pointers, Index _strides[kStorageRank - 1])
Definition: tensor_ref_collection.h:390
Base TensorRef
TensorRef returned by the iterator.
Definition: tensor_ref_collection.h:297
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Advances to next TensorRef.
Definition: tensor_ref_collection.h:327