Cutlass
CUDA Templates for Linear Algebra Subroutines and Solvers
predicate_vector.h
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <stdint.h>
32 
33 #include <cutlass/cutlass.h>
34 #include <cutlass/shape.h>
35 
36 #include <cutlass/util/platform.h>
37 
38 namespace cutlass {
39 
41 
58 
78 
94 
97 template <
99  int kPredicates_,
101  int kPredicatesPerByte_ = 4,
103  int kPredicateStart_ = 0>
106  static int const kPredicates = kPredicates_;
107 
109  static int const kPredicatesPerByte = kPredicatesPerByte_;
110 
112  static int const kPredicateStart = kPredicateStart_;
113 
114  // Make sure no one tries to put more than 8 bits in a byte :)
115  static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
116  // Make sure the "offsetted" bits fit in one byte.
118  "The offsetted predicates must fit within an actual byte.");
119 
121  typedef uint32_t Storage;
122 
125 
127  static int const kWordCount = (kBytes + sizeof(Storage) - 1) / sizeof(Storage);
128 
129  private:
130  //
131  // Data members
132  //
133 
135  Storage storageData[kWordCount];
136 
137  //
138  // Methods
139  //
140 
142  CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {
144 
145  int byte = (idx / kPredicatesPerByte);
146  int bit_offset = (idx % kPredicatesPerByte);
147 
148  word = byte / sizeof(Storage);
149  int byte_offset = (byte % sizeof(Storage));
150 
151  bit = byte_offset * 8 + bit_offset + kPredicateStart;
152  }
153 
155  CUTLASS_HOST_DEVICE Storage &storage(int word) {
156  CUTLASS_ASSERT(word < kWordCount);
157  return storageData[word];
158  }
159 
161  CUTLASS_HOST_DEVICE Storage const &storage(int word) const {
162  CUTLASS_ASSERT(word < kWordCount);
163  return storageData[word];
164  }
165 
166  public:
167  //
168  // Iterator
169  //
170 
178  PredicateVector const &vec_;
179 
181  int bit_;
182 
183  public:
186  ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
187 
190  ConstIterator(PredicateVector const &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
191 
195  ++bit_;
196  return *this;
197  }
198 
202  --bit_;
203  return *this;
204  }
205 
209  ConstIterator ret(*this);
210  ret.bit_++;
211  return ret;
212  }
213 
217  ConstIterator ret(*this);
218  ret.bit_--;
219  return ret;
220  }
221 
224  bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }
225 
228  bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }
229 
232  bool operator*() const { return vec_[bit_]; }
233  };
234 
240  class Iterator {
242  PredicateVector &vec_;
243 
245  int bit_;
246 
247  public:
250  Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}
251 
254  Iterator(PredicateVector &_vec, int _start = 0) : vec_(_vec), bit_(_start) {}
255 
259  ++bit_;
260  return *this;
261  }
262 
266  --bit_;
267  return *this;
268  }
269 
273  Iterator ret(*this);
274  ret.bit_++;
275  return ret;
276  }
277 
281  Iterator ret(*this);
282  ret.bit_--;
283  return ret;
284  }
285 
288  bool operator==(Iterator const &it) const { return bit_ == it.bit_; }
289 
292  bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }
293 
296  bool get() { return vec_[bit_]; }
297 
300  bool operator*() const { return vec_[bit_]; }
301 
304  void set(bool value = true) { vec_.set(bit_, value); }
305  };
306 
312 
315  TrivialIterator(Iterator const &it) {}
316 
320 
323  TrivialIterator &operator++() { return *this; }
324 
327  TrivialIterator operator++(int) { return *this; }
328 
331  bool operator*() const { return true; }
332  };
333 
334  public:
335  //
336  // Methods
337  //
338 
340  CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }
341 
343  CUTLASS_HOST_DEVICE void fill(bool value = true) {
344  Storage item = (value ? ~Storage(0) : Storage(0));
345 
347  for (int i = 0; i < kWordCount; ++i) {
348  storage(i) = item;
349  }
350  }
351 
353  CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }
354 
356  CUTLASS_HOST_DEVICE bool at(int idx) const {
357  int bit, word;
358  computeStorageOffset(word, bit, idx);
359 
360  return ((storage(word) >> bit) & 1);
361  }
362 
364  CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
365  int bit, word;
366  computeStorageOffset(word, bit, idx);
367 
368  Storage disable_mask = (~(Storage(1) << bit));
369  Storage enable_mask = (Storage(value) << bit);
370 
371  storage(word) = ((storage(word) & disable_mask) | enable_mask);
372  }
373 
377  for (int i = 0; i < kWordCount; ++i) {
378  storage(i) = (storage(i) & predicates.storage(i));
379  }
380  return *this;
381  }
382 
386  for (int i = 0; i < kWordCount; ++i) {
387  storage(i) = (storage(i) | predicates.storage(i));
388  }
389  return *this;
390  }
391 
394  Storage mask(0);
395  for (int byte = 0; byte < sizeof(Storage); ++byte) {
396  Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
397  mask |= (byte_mask << (byte * 8));
398  }
399  uint32_t result = 0;
400  for (int word = 0; word < kWordCount; ++word) {
401  result |= storage(word);
402  }
403  return result == 0;
404  }
405 
407  CUTLASS_DEVICE
408  Iterator begin() { return Iterator(*this); }
409 
411  CUTLASS_DEVICE
412  Iterator end() { return Iterator(*this, kPredicates); }
413 
415  CUTLASS_DEVICE
416  ConstIterator const_begin() const { return ConstIterator(*this); }
417 
419  CUTLASS_DEVICE
420  ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }
421 };
422 
424 
429 
431  CUTLASS_HOST_DEVICE bool at(int, int, int, int) const { return true; }
432 };
433 
435 
437 template <typename PredicateVector_, typename Iterations_>
440  typedef PredicateVector_ PredicateVector;
442  typedef Iterations_ Iterations;
443 
444  private:
446  PredicateVector &predicates;
447 
448  public:
450  CUTLASS_DEVICE PredicateTileAdapter(PredicateVector &predicates_) : predicates(predicates_) {}
451 
453  CUTLASS_DEVICE bool at(int d, int h, int w, int c) const {
454  int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
455  return predicates.at(bit);
456  }
457 
459  CUTLASS_DEVICE void set(int d, int h, int w, int c, bool value) {
460  int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
461  predicates.set(bit, value);
462  }
463 };
464 
466 
468 template <typename PredicateVector_, typename Iterations_>
471  typedef PredicateVector_ PredicateVector;
473  typedef Iterations_ Iterations;
474 
475  private:
477  PredicateVector const &predicates;
478 
479  public:
481  CUTLASS_DEVICE ConstPredicateTileAdapter(PredicateVector const &predicates_)
482  : predicates(predicates_) {}
483 
485  CUTLASS_DEVICE bool at(int d, int h, int w, int c) const {
486  int const bit = ComputeOffsetFromShape<Iterations>::get(d, h, w, c);
487  return predicates.at(bit);
488  }
489 };
490 
492 
493 } // namespace cutlass
CUTLASS_HOST_DEVICE Iterator(PredicateVector &_vec, int _start=0)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:254
CUTLASS_HOST_DEVICE bool operator!=(ConstIterator const &it) const
Returns false if iterators point to the same bit.
Definition: predicate_vector.h:228
CUTLASS_HOST_DEVICE PredicateVector & operator|=(PredicateVector const &predicates)
Computes the union of two identical predicate vectors.
Definition: predicate_vector.h:384
CUTLASS_HOST_DEVICE TrivialIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:323
Definition: convert.h:33
CUTLASS_HOST_DEVICE bool is_zero() const
Returns true if entire predicate array is zero.
Definition: predicate_vector.h:393
uint32_t Storage
Storage type of individual elements.
Definition: predicate_vector.h:115
CUTLASS_HOST_DEVICE TrivialIterator(PredicateVector const &_vec)
Constructs an iterator from a PredicateVector.
Definition: predicate_vector.h:319
CUTLASS_HOST_DEVICE ConstIterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:201
static int const kBytes
Number of bytes needed.
Definition: predicate_vector.h:124
CUTLASS_DEVICE ConstIterator const_begin() const
Returns a ConstIterator to the start of the bit vector.
Definition: predicate_vector.h:416
CUTLASS_HOST_DEVICE ConstIterator(PredicateVector const &_vec, int _start=0)
Definition: predicate_vector.h:190
CUTLASS_HOST_DEVICE bool at(int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:356
CUTLASS_HOST_DEVICE ConstIterator & operator++()
Pre-increment.
Definition: predicate_vector.h:194
PredicateVector_ PredicateVector
The vector of predicates.
Definition: predicate_vector.h:440
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:166
CUTLASS_HOST_DEVICE ConstIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:208
CUTLASS_HOST_DEVICE Iterator operator++(int)
Post-increment.
Definition: predicate_vector.h:272
Adapter to enable random access to predicates via logical coordinate within a tile.
Definition: predicate_vector.h:438
CUTLASS_HOST_DEVICE TrivialIterator(Iterator const &it)
Copy constructor.
Definition: predicate_vector.h:315
C++ features that may be otherwise unimplemented for CUDA device functions.
Iterator that always returns true.
Definition: predicate_vector.h:308
CUTLASS_HOST_DEVICE TrivialIterator operator++(int)
Post-increment.
Definition: predicate_vector.h:327
CUTLASS_HOST_DEVICE bool operator==(Iterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:288
CUTLASS_DEVICE PredicateTileAdapter(PredicateVector &predicates_)
Ctor.
Definition: predicate_vector.h:450
CUTLASS_DEVICE bool at(int d, int h, int w, int c) const
Get the value at location (d, h, w, c).
Definition: predicate_vector.h:453
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:60
CUTLASS_DEVICE bool at(int d, int h, int w, int c) const
Get the value at location (d, h, w, c).
Definition: predicate_vector.h:485
CUTLASS_HOST_DEVICE Iterator & operator--()
Pre-decrement.
Definition: predicate_vector.h:265
PredicateVector_ PredicateVector
The vector of predicates.
Definition: predicate_vector.h:471
CUTLASS_HOST_DEVICE PredicateVector & operator &=(PredicateVector const &predicates)
Computes the intersection of two identical predicate vectors.
Definition: predicate_vector.h:375
CUTLASS_HOST_DEVICE Iterator(Iterator const &it)
Copy constructor.
Definition: predicate_vector.h:250
CUTLASS_HOST_DEVICE bool operator[](int idx) const
Accesses a bit within the predicate vector.
Definition: predicate_vector.h:353
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:300
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:331
CUTLASS_HOST_DEVICE void fill(bool value=true)
Fills all predicates with a given value.
Definition: predicate_vector.h:343
static int const kPredicates
Number of bits stored by the PredicateVector.
Definition: predicate_vector.h:106
CUTLASS_DEVICE Iterator end()
Returns an iterator to the end of the bit vector (one past the last predicate).
Definition: predicate_vector.h:412
#define CUTLASS_ASSERT(x)
Definition: cutlass.h:64
CUTLASS_HOST_DEVICE bool at(int, int, int, int) const
Returns true for every location (d, h, w, c).
Definition: predicate_vector.h:431
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
static int const kPredicatesPerByte
Number of bits stored within each byte of the predicate bit vector.
Definition: predicate_vector.h:109
#define static_assert(__e, __m)
Definition: platform.h:145
Statically sized array of bits implementing the Predicate Vector Concept.
Definition: predicate_vector.h:104
static int const kWordCount
Number of storage elements needed.
Definition: predicate_vector.h:127
CUTLASS_DEVICE ConstIterator const_end() const
Returns a ConstIterator to the end of the bit vector (one past the last predicate).
Definition: predicate_vector.h:420
Always returns true predicate.
Definition: predicate_vector.h:426
CUTLASS_HOST_DEVICE Iterator & operator++()
Pre-increment.
Definition: predicate_vector.h:258
A const iterator implementing the Predicate Iterator Concept, enabling sequential read-only access to predicates.
Definition: predicate_vector.h:176
CUTLASS_HOST_DEVICE void set(int idx, bool value=true)
Set a bit within the predicate vector.
Definition: predicate_vector.h:364
CUTLASS_HOST_DEVICE bool operator==(ConstIterator const &it) const
Returns true if iterators point to the same bit.
Definition: predicate_vector.h:224
Iterations_ Iterations
Shape of the iterations; used to map a (d, h, w, c) coordinate to a predicate index.
Definition: predicate_vector.h:473
Iterations_ Iterations
Shape of the iterations; used to map a (d, h, w, c) coordinate to a predicate index.
Definition: predicate_vector.h:442
CUTLASS_HOST_DEVICE bool operator*() const
Dereferences iterator.
Definition: predicate_vector.h:232
CUTLASS_HOST_DEVICE bool operator!=(Iterator const &it) const
Returns false if iterators point to the same bit.
Definition: predicate_vector.h:292
static int const kPredicateStart
First bit within each byte containing predicates.
Definition: predicate_vector.h:112
CUTLASS_HOST_DEVICE ConstIterator(ConstIterator const &it)
Copy constructor.
Definition: predicate_vector.h:186
CUTLASS_HOST_DEVICE TrivialPredicateTileAdapter()
Ctor.
Definition: predicate_vector.h:428
CUTLASS_HOST_DEVICE ConstIterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:216
Adapter to enable random access to predicates via logical coordinate within a tile.
Definition: predicate_vector.h:469
CUTLASS_DEVICE ConstPredicateTileAdapter(PredicateVector const &predicates_)
Ctor.
Definition: predicate_vector.h:481
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
CUTLASS_HOST_DEVICE PredicateVector(bool value=true)
Initialize the predicate vector.
Definition: predicate_vector.h:340
CUTLASS_DEVICE Iterator begin()
Returns an iterator to the start of the bit vector.
Definition: predicate_vector.h:408
Basic include for CUTLASS macros.
An iterator implementing the Predicate Iterator Concept, enabling sequential read and write access to predicates.
Definition: predicate_vector.h:240
CUTLASS_HOST_DEVICE Iterator operator--(int)
Post-decrement.
Definition: predicate_vector.h:280
CUTLASS_HOST_DEVICE TrivialIterator()
Constructor.
Definition: predicate_vector.h:311