cutlass/include/cutlass/predicate_vector.h
ANIKET SHIVAM 66d9cddc83
New updates for 2.11 (#775)
* New updates.

* Minor profiler updates

Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
2023-01-20 16:32:57 -05:00

525 lines
15 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines container classes and iterators for managing a statically sized vector
of boolean predicates.
*/
#pragma once
#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/cstdint>
#else
#include <assert.h>
#include <stdint.h>
#endif
#include "cutlass/cutlass.h"
#include "cutlass/platform/platform.h"
namespace cutlass {
////////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup predicate_vector_concept Predicate Vector Concept
@{
Implementations of \ref predicate_vector_concept contain an ordered set of boolean predicates which
may be used as conditionals in other device-side operations. Both random access and iterators
offering sequential access are provided.
@par Predicate Vector
A \ref predicate_vector_concept satisfies the following expressions
- <b>at(int idx)</b> - returns the value of the indexed predicate
- <b>set(int idx, bool value)</b> - sets the value of the indexed predicate
- <b>begin()</b> - returns a \ref predicate_iterator_concept pointing to the first predicate
@}
*/
////////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup predicate_iterator_concept Predicate Iterator Concept
@{
Implementations of \ref predicate_iterator_concept enable accessing and traversing elements of a
bit vector.
@par Const Predicate Iterator
A const \ref predicate_iterator_concept satisfies the following expressions
- <b>++it</b> increments the iterator to the next predicate
- <b>*it</b> returns the value of the currently pointed-to predicate
@par Mutable Predicate Iterator
A \ref predicate_iterator_concept that is non-const <b>also</b> satisfies the following expressions
- <b>it.set(bool value)</b> sets the value of the currently pointed-to predicate
@}
*/
////////////////////////////////////////////////////////////////////////////////////////////////////
/*!@defgroup predicate_tile_adapter Predicate Tile Adapter Concept
@{
Implementations of \ref predicate_tile_adapter provide a mapping between the elements of a \ref
tile_traits_concept and a \ref predicate_vector_concept.
@par Predicate Tile Adapter
A \ref predicate_tile_adapter satisfies the following expressions
- <b>at(int d, int h, int w, int c)</b> - returns the value of a predicate corresponding to the
access (d, h, w, c) within the tile.
@}
*/
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically sized array of bits implementing @concept{predicate_vector_concept}.
///
/// Predicates are packed into 32-bit storage words. Every byte of a word holds up to
/// kPredicatesPerByte predicates in consecutive bits, beginning at bit kPredicateStart of that
/// byte. The remaining bits of each byte (and any bytes past the last predicate) are padding:
/// at() and set() never address them, although fill(), clear() and enable() overwrite them.
template <
    /// Number of predicates contained in predicate vector
    int kPredicates_,
    /// Number of predicates contained in each byte of internal storage
    int kPredicatesPerByte_ = 4,
    /// Location of first predicate within byte of internal storage
    int kPredicateStart_ = 0>
struct PredicateVector {
  /// Number of bits stored by the PredicateVector
  static int const kPredicates = kPredicates_;

  /// Number of bits stored within each byte of the predicate bit vector
  static int const kPredicatesPerByte = kPredicatesPerByte_;

  /// First bit within each byte containing predicates
  static int const kPredicateStart = kPredicateStart_;

  // Make sure no one tries to put more than 8 bits in a byte :)
  static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");

  // Make sure the "offsetted" bits fit in one byte.
  static_assert(kPredicateStart + kPredicatesPerByte <= 8,
                "The offsetted predicates must fit within an actual byte.");

  /// Storage type of individual elements
  typedef uint32_t Storage;

  /// Number of bytes needed
  static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;

  /// Number of storage elements needed
  static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));

 private:
  //
  // Data members
  //

  /// Words of bit vector
  Storage storageData[kWordCount];

  //
  // Methods
  //

  /// Computes the word and bit corresponding to a logical predicate index.
  ///
  /// @param[out] word  index of the storage word holding the predicate
  /// @param[out] bit   bit position of the predicate within that word
  /// @param[in]  idx   logical predicate index, expected in [0, kPredicates)
  CUTLASS_HOST_DEVICE void computeStorageOffset(int &word, int &bit, int idx) const {
    // A negative index would produce a negative shift amount in at()/set().
    CUTLASS_ASSERT(idx >= 0 && idx < kPredicates);

    // Plain int arithmetic avoids mixing signed values with the unsigned result of sizeof.
    int const kBytesPerWord = int(sizeof(Storage));

    int byte = (idx / kPredicatesPerByte);
    int bit_offset = (idx % kPredicatesPerByte);

    word = byte / kBytesPerWord;
    int byte_offset = (byte % kBytesPerWord);

    bit = byte_offset * 8 + bit_offset + kPredicateStart;
  }

  /// Builds the mask of bits in storage word `word` that hold predicates 0..kPredicates-1.
  ///
  /// Padding bits (bits of a byte outside the predicate window, the unused tail of the last
  /// occupied byte, and bytes past the last predicate) are zero in the returned mask.
  CUTLASS_HOST_DEVICE static Storage predicateMask(int word) {
    int const kBytesPerWord = int(sizeof(Storage));

    Storage mask = 0;

    CUTLASS_PRAGMA_UNROLL
    for (int b = 0; b < kBytesPerWord; ++b) {
      // Logical index of the first predicate stored in this byte.
      int first = (word * kBytesPerWord + b) * kPredicatesPerByte;

      // Number of predicates living in this byte: full, partial (last byte) or none (padding).
      int count = kPredicates - first;
      if (count > kPredicatesPerByte) {
        count = kPredicatesPerByte;
      }

      if (count > 0) {
        Storage byte_mask = (((Storage(1) << count) - Storage(1)) << kPredicateStart);
        mask |= (byte_mask << (b * 8));
      }
    }

    return mask;
  }

  /// Accesses a given word with optional assertions
  CUTLASS_HOST_DEVICE Storage &storage(int word) {
    CUTLASS_ASSERT(word < kWordCount);
    return storageData[word];
  }

  /// Accesses a given word with optional assertions
  CUTLASS_HOST_DEVICE Storage const &storage(int word) const {
    CUTLASS_ASSERT(word < kWordCount);
    return storageData[word];
  }

 public:
  //
  // Iterator
  //

  /**
   * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential
   * read and write access to predicates.
   * @concept{predicate_iterator_concept}
   */
  class Iterator {
    /// Reference to PredicateVector instance
    PredicateVector &vec_;

    /// Index into PredicateVector
    int bit_;

   public:
    /// Copy constructor
    CUTLASS_HOST_DEVICE
    Iterator(Iterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

    /// Constructs an iterator from a PredicateVector, positioned at predicate `_start`
    CUTLASS_HOST_DEVICE
    Iterator(PredicateVector &vec, int _start = 0) : vec_(vec), bit_(_start) {}

    /// Pre-increment
    CUTLASS_HOST_DEVICE
    Iterator &operator++() {
      ++bit_;
      return *this;
    }

    /// Increment
    CUTLASS_HOST_DEVICE
    Iterator &operator+=(int offset) {
      bit_ += offset;
      return *this;
    }

    /// Pre-decrement
    CUTLASS_HOST_DEVICE
    Iterator &operator--() {
      --bit_;
      return *this;
    }

    /// Decrement
    CUTLASS_HOST_DEVICE
    Iterator &operator-=(int offset) {
      bit_ -= offset;
      return *this;
    }

    /// Post-increment: advances this iterator and returns its previous position
    CUTLASS_HOST_DEVICE
    Iterator operator++(int) {
      Iterator ret(*this);
      ++bit_;
      return ret;
    }

    /// Post-decrement: moves this iterator back and returns its previous position
    CUTLASS_HOST_DEVICE
    Iterator operator--(int) {
      Iterator ret(*this);
      --bit_;
      return ret;
    }

    /// Returns an iterator advanced by some amount
    CUTLASS_HOST_DEVICE
    Iterator operator+(int offset) const {
      Iterator ret(*this);
      ret.bit_ += offset;
      return ret;
    }

    /// Returns an iterator receded by some amount
    CUTLASS_HOST_DEVICE
    Iterator operator-(int offset) const {
      Iterator ret(*this);
      ret.bit_ -= offset;
      return ret;
    }

    /// Returns true if iterators point to the same bit (the referenced vectors are not compared)
    CUTLASS_HOST_DEVICE
    bool operator==(Iterator const &it) const { return bit_ == it.bit_; }

    /// Returns true if iterators point to different bits
    CUTLASS_HOST_DEVICE
    bool operator!=(Iterator const &it) const { return bit_ != it.bit_; }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool get() const { return vec_.at(bit_); }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool at() const { return vec_.at(bit_); }

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return at(); }

    /// Sets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    void set(bool value = true) { vec_.set(bit_, value); }
  };

  /**
   * @brief An iterator implementing \ref predicate_iterator_concept enabling sequential
   * read-only access to predicates.
   * @concept{predicate_iterator_concept}
   */
  class ConstIterator {
    /// Reference to PredicateVector instance
    PredicateVector const &vec_;

    /// Index into PredicateVector
    int bit_;

   public:
    /// Copy constructor
    CUTLASS_HOST_DEVICE
    ConstIterator(ConstIterator const &it) : vec_(it.vec_), bit_(it.bit_) {}

    /// Constructs an iterator from a PredicateVector, positioned at predicate `_start`
    CUTLASS_HOST_DEVICE
    ConstIterator(PredicateVector const &vec, int _start = 0) : vec_(vec), bit_(_start) {}

    /// Pre-increment
    CUTLASS_HOST_DEVICE
    ConstIterator &operator++() {
      ++bit_;
      return *this;
    }

    /// Increment
    CUTLASS_HOST_DEVICE
    ConstIterator &operator+=(int offset) {
      bit_ += offset;
      return *this;
    }

    /// Pre-decrement
    CUTLASS_HOST_DEVICE
    ConstIterator &operator--() {
      --bit_;
      return *this;
    }

    /// Decrement
    CUTLASS_HOST_DEVICE
    ConstIterator &operator-=(int offset) {
      bit_ -= offset;
      return *this;
    }

    /// Post-increment: advances this iterator and returns its previous position
    CUTLASS_HOST_DEVICE
    ConstIterator operator++(int) {
      ConstIterator ret(*this);
      ++bit_;
      return ret;
    }

    /// Post-decrement: moves this iterator back and returns its previous position
    CUTLASS_HOST_DEVICE
    ConstIterator operator--(int) {
      ConstIterator ret(*this);
      --bit_;
      return ret;
    }

    /// Returns an iterator advanced by some amount
    CUTLASS_HOST_DEVICE
    ConstIterator operator+(int offset) const {
      ConstIterator ret(*this);
      ret.bit_ += offset;
      return ret;
    }

    /// Returns an iterator receded by some amount
    CUTLASS_HOST_DEVICE
    ConstIterator operator-(int offset) const {
      ConstIterator ret(*this);
      ret.bit_ -= offset;
      return ret;
    }

    /// Returns true if iterators point to the same bit (the referenced vectors are not compared)
    CUTLASS_HOST_DEVICE
    bool operator==(ConstIterator const &it) const { return bit_ == it.bit_; }

    /// Returns true if iterators point to different bits
    CUTLASS_HOST_DEVICE
    bool operator!=(ConstIterator const &it) const { return bit_ != it.bit_; }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool get() const { return vec_.at(bit_); }

    /// Gets the bit at the pointed to location
    CUTLASS_HOST_DEVICE
    bool at() const { return vec_.at(bit_); }

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return at(); }
  };

  /// Iterator that always returns true
  struct TrivialIterator {
    /// Constructor
    CUTLASS_HOST_DEVICE
    TrivialIterator() {}

    /// Converting constructor from a mutable Iterator; the argument is ignored
    CUTLASS_HOST_DEVICE
    TrivialIterator(Iterator const &) {}

    /// Constructs an iterator from a PredicateVector; the argument is ignored
    CUTLASS_HOST_DEVICE
    TrivialIterator(PredicateVector const &) {}

    /// Pre-increment (no-op)
    CUTLASS_HOST_DEVICE
    TrivialIterator &operator++() { return *this; }

    /// Post-increment (no-op)
    CUTLASS_HOST_DEVICE
    TrivialIterator operator++(int) { return *this; }

    /// Dereferences iterator
    CUTLASS_HOST_DEVICE
    bool operator*() const { return true; }
  };

 public:
  //
  // Methods
  //

  /// Initialize the predicate vector with every predicate set to `value`
  CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }

  /// Fills all predicates with a given value (padding bits receive the same value)
  CUTLASS_HOST_DEVICE void fill(bool value = true) {
    Storage item = (value ? ~Storage(0) : Storage(0));

    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kWordCount; ++i) {
      storage(i) = item;
    }
  }

  /// Clears all predicates
  CUTLASS_HOST_DEVICE void clear() { fill(false); }

  /// Sets all predicates to true
  CUTLASS_HOST_DEVICE void enable() { fill(true); }

  /// Accesses a bit within the predicate vector.
  CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }

  /// Accesses a bit within the predicate vector.
  CUTLASS_HOST_DEVICE bool at(int idx) const {
    int bit, word;
    computeStorageOffset(word, bit, idx);

    return ((storage(word) >> bit) & 1) != 0;
  }

  /// Set a bit within the predicate vector.
  CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
    int bit, word;
    computeStorageOffset(word, bit, idx);

    // Clear the target bit first, then OR in the requested value.
    Storage disable_mask = (~(Storage(1) << bit));
    Storage enable_mask = (Storage(value) << bit);

    storage(word) = ((storage(word) & disable_mask) | enable_mask);
  }

  /// Computes the intersection of two identical predicate vectors.
  CUTLASS_HOST_DEVICE PredicateVector &operator&=(PredicateVector const &predicates) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kWordCount; ++i) {
      storage(i) = (storage(i) & predicates.storage(i));
    }
    return *this;
  }

  /// Computes the union of two identical predicate vectors.
  CUTLASS_HOST_DEVICE PredicateVector &operator|=(PredicateVector const &predicates) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < kWordCount; ++i) {
      storage(i) = (storage(i) | predicates.storage(i));
    }
    return *this;
  }

  /// Returns true if every predicate in the vector is false.
  ///
  /// Only bits that hold predicates are examined. Padding bits, which fill(true) and enable()
  /// leave set, do not influence the result.
  CUTLASS_HOST_DEVICE bool is_zero() const {
    Storage result = 0;

    CUTLASS_PRAGMA_UNROLL
    for (int word = 0; word < kWordCount; ++word) {
      result |= (storage(word) & predicateMask(word));
    }

    return result == 0;
  }

  /// Returns an iterator to the start of the bit vector
  CUTLASS_DEVICE
  Iterator begin() { return Iterator(*this); }

  /// Returns an iterator positioned one past the last predicate
  CUTLASS_DEVICE
  Iterator end() { return Iterator(*this, kPredicates); }

  /// Returns a ConstIterator to the start of the bit vector
  CUTLASS_DEVICE
  ConstIterator const_begin() const { return ConstIterator(*this); }

  /// Returns a ConstIterator positioned one past the last predicate
  CUTLASS_DEVICE
  ConstIterator const_end() const { return ConstIterator(*this, kPredicates); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass