Fix is_zero (#1147)
* Fix is_zero * Use constexpr * Add CUTLASS_PRAGMA_UNROLL to loops * Avoid if branches in is_zero
This commit is contained in:
parent
fb10fa5308
commit
7a7796afae
@ -106,7 +106,7 @@ tile_traits_concept and a \ref predicate_vector_concept.
|
|||||||
|
|
||||||
/// Statically sized array of bits implementing @concept{predicate_vector_concept}.
|
/// Statically sized array of bits implementing @concept{predicate_vector_concept}.
|
||||||
template <
|
template <
|
||||||
/// Number of predicates conatined in predicate vector
|
/// Number of predicates contained in predicate vector
|
||||||
int kPredicates_,
|
int kPredicates_,
|
||||||
/// Number of predicates contained in each byte of internal storage
|
/// Number of predicates contained in each byte of internal storage
|
||||||
int kPredicatesPerByte_ = 4,
|
int kPredicatesPerByte_ = 4,
|
||||||
@ -114,13 +114,13 @@ template <
|
|||||||
int kPredicateStart_ = 0>
|
int kPredicateStart_ = 0>
|
||||||
struct PredicateVector {
|
struct PredicateVector {
|
||||||
/// Number of bits stored by the PredicateVector
|
/// Number of bits stored by the PredicateVector
|
||||||
static int const kPredicates = kPredicates_;
|
static constexpr int kPredicates = kPredicates_;
|
||||||
|
|
||||||
/// Number of bits stored within each byte of the predicate bit vector
|
/// Number of bits stored within each byte of the predicate bit vector
|
||||||
static int const kPredicatesPerByte = kPredicatesPerByte_;
|
static constexpr int kPredicatesPerByte = kPredicatesPerByte_;
|
||||||
|
|
||||||
/// First bit withing each byte containing predicates
|
/// First bit within each byte containing predicates
|
||||||
static int const kPredicateStart = kPredicateStart_;
|
static constexpr int kPredicateStart = kPredicateStart_;
|
||||||
|
|
||||||
// Make sure no one tries to put more than 8 bits in a byte :)
|
// Make sure no one tries to put more than 8 bits in a byte :)
|
||||||
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
|
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
|
||||||
@ -132,10 +132,13 @@ struct PredicateVector {
|
|||||||
typedef uint32_t Storage;
|
typedef uint32_t Storage;
|
||||||
|
|
||||||
/// Number of bytes needed
|
/// Number of bytes needed
|
||||||
static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
static constexpr int kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
||||||
|
|
||||||
/// Number of storage elements needed
|
/// Number of storage elements needed
|
||||||
static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
|
static constexpr int kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
|
||||||
|
|
||||||
|
/// The byte mask corresponding to predicates
|
||||||
|
static constexpr Storage kByteMask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//
|
//
|
||||||
@ -162,6 +165,27 @@ struct PredicateVector {
|
|||||||
bit = byte_offset * 8 + bit_offset + kPredicateStart;
|
bit = byte_offset * 8 + bit_offset + kPredicateStart;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns word mask.
|
||||||
|
/// Builds the mask selecting every predicate bit of one full storage word.
///
/// kByteMask covers the predicate bits of a single byte; it is replicated into
/// each of the sizeof(Storage) bytes of the word.
///
/// The return type must be Storage: returning bool would collapse the mask to
/// true/false (i.e. 1 or 0) and callers would then test only the lowest bit.
CUTLASS_HOST_DEVICE static constexpr Storage computeWordMask() {
  Storage mask(0);
  CUTLASS_PRAGMA_UNROLL
  for (int byte = 0; byte < int(sizeof(Storage)); ++byte) {
    // Shift the per-byte mask into position for this byte of the word.
    mask |= (kByteMask << (byte * 8));
  }
  return mask;
}
|
||||||
|
|
||||||
|
/// Returns mask of last word.
|
||||||
|
/// Builds the mask selecting the predicate bits of the final storage word.
///
/// Only the bytes of the last word that are counted in kBytes contribute:
/// kBytes % sizeof(Storage) bytes, or a whole word when that remainder is zero.
/// kByteMask is replicated into each of those bytes; the remaining bytes of
/// the word are left unmasked (zero) so their contents are ignored.
///
/// The return type must be Storage: returning bool would collapse the mask to
/// true/false (i.e. 1 or 0) and callers would then test only the lowest bit.
CUTLASS_HOST_DEVICE static constexpr Storage computeLastWordMask() {
  Storage mask(0);
  // Number of bytes in use within the last word.
  constexpr int kBytesPerWord = int(sizeof(Storage));
  constexpr int count = (kBytes % kBytesPerWord == 0) ? kBytesPerWord : (kBytes % kBytesPerWord);
  CUTLASS_PRAGMA_UNROLL
  for (int byte = 0; byte < count; ++byte) {
    // Shift the per-byte mask into position for this byte of the word.
    mask |= (kByteMask << (byte * 8));
  }
  return mask;
}
|
||||||
|
|
||||||
/// Accesses a given word with optional assertions
|
/// Accesses a given word with optional assertions
|
||||||
CUTLASS_HOST_DEVICE Storage &storage(int word) {
|
CUTLASS_HOST_DEVICE Storage &storage(int word) {
|
||||||
CUTLASS_ASSERT(word < kWordCount);
|
CUTLASS_ASSERT(word < kWordCount);
|
||||||
@ -490,15 +514,14 @@ struct PredicateVector {
|
|||||||
|
|
||||||
/// Returns true if entire predicate array is zero.
|
/// Returns true if entire predicate array is zero.
|
||||||
CUTLASS_HOST_DEVICE bool is_zero() const {
|
CUTLASS_HOST_DEVICE bool is_zero() const {
|
||||||
Storage mask(0);
|
constexpr Storage mask = computeWordMask();
|
||||||
for (int byte = 0; byte < sizeof(Storage); ++byte) {
|
Storage result = 0;
|
||||||
Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
|
CUTLASS_PRAGMA_UNROLL
|
||||||
mask |= (byte_mask << (byte * 8));
|
for (int word = 0; word < kWordCount - 1; ++word) {
|
||||||
}
|
result |= (storage(word) & mask);
|
||||||
uint32_t result = 0;
|
|
||||||
for (int word = 0; word < kWordCount; ++word) {
|
|
||||||
result |= storage(word);
|
|
||||||
}
|
}
|
||||||
|
constexpr Storage last_word_mask = computeLastWordMask();
|
||||||
|
result |= (storage(kWordCount - 1) & last_word_mask);
|
||||||
return result == 0;
|
return result == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user