Fix is_zero (#1147)

* Fix is_zero

* Use constexpr

* Add CUTLASS_PRAGMA_UNROLL to loops

* Avoid if branches in is_zero
This commit is contained in:
cyyever 2023-10-24 00:09:37 +08:00 committed by GitHub
parent fb10fa5308
commit 7a7796afae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -106,7 +106,7 @@ tile_traits_concept and a \ref predicate_vector_concept.
/// Statically sized array of bits implementing @concept{predicate_vector_concept}. /// Statically sized array of bits implementing @concept{predicate_vector_concept}.
template < template <
/// Number of predicates conatined in predicate vector /// Number of predicates contained in predicate vector
int kPredicates_, int kPredicates_,
/// Number of predicates contained in each byte of internal storage /// Number of predicates contained in each byte of internal storage
int kPredicatesPerByte_ = 4, int kPredicatesPerByte_ = 4,
@ -114,13 +114,13 @@ template <
int kPredicateStart_ = 0> int kPredicateStart_ = 0>
struct PredicateVector { struct PredicateVector {
/// Number of bits stored by the PredicateVector /// Number of bits stored by the PredicateVector
static int const kPredicates = kPredicates_; static constexpr int kPredicates = kPredicates_;
/// Number of bits stored within each byte of the predicate bit vector /// Number of bits stored within each byte of the predicate bit vector
static int const kPredicatesPerByte = kPredicatesPerByte_; static constexpr int kPredicatesPerByte = kPredicatesPerByte_;
/// First bit withing each byte containing predicates /// First bit within each byte containing predicates
static int const kPredicateStart = kPredicateStart_; static constexpr int kPredicateStart = kPredicateStart_;
// Make sure no one tries to put more than 8 bits in a byte :) // Make sure no one tries to put more than 8 bits in a byte :)
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte"); static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
@ -132,10 +132,13 @@ struct PredicateVector {
typedef uint32_t Storage; typedef uint32_t Storage;
/// Number of bytes needed /// Number of bytes needed
static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; static constexpr int kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
/// Number of storage elements needed /// Number of storage elements needed
static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); static constexpr int kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
/// The byte mask corresponding to predicates
static constexpr Storage kByteMask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
private: private:
// //
@ -162,6 +165,27 @@ struct PredicateVector {
bit = byte_offset * 8 + bit_offset + kPredicateStart; bit = byte_offset * 8 + bit_offset + kPredicateStart;
} }
/// Returns word mask.
CUTLASS_HOST_DEVICE static constexpr bool computeWordMask() {
Storage mask(0);
CUTLASS_PRAGMA_UNROLL
for (int byte = 0; byte < sizeof(Storage); ++byte) {
mask |= (kByteMask << (byte * 8));
}
return mask;
}
/// Returns mask of last word.
CUTLASS_HOST_DEVICE static constexpr bool computeLastWordMask() {
Storage mask(0);
constexpr int count = (kBytes % sizeof(Storage) == 0) ? sizeof(Storage) : (kBytes % sizeof(Storage));
CUTLASS_PRAGMA_UNROLL
for (int byte = 0; byte < count; ++byte) {
mask |= (kByteMask << (byte * 8));
}
return mask;
}
/// Accesses a given word with optional assertions /// Accesses a given word with optional assertions
CUTLASS_HOST_DEVICE Storage &storage(int word) { CUTLASS_HOST_DEVICE Storage &storage(int word) {
CUTLASS_ASSERT(word < kWordCount); CUTLASS_ASSERT(word < kWordCount);
@ -490,15 +514,14 @@ struct PredicateVector {
/// Returns true if entire predicate array is zero. /// Returns true if entire predicate array is zero.
CUTLASS_HOST_DEVICE bool is_zero() const { CUTLASS_HOST_DEVICE bool is_zero() const {
Storage mask(0); constexpr Storage mask = computeWordMask();
for (int byte = 0; byte < sizeof(Storage); ++byte) { Storage result = 0;
Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart); CUTLASS_PRAGMA_UNROLL
mask |= (byte_mask << (byte * 8)); for (int word = 0; word < kWordCount - 1; ++word) {
} result |= (storage(word) & mask);
uint32_t result = 0;
for (int word = 0; word < kWordCount; ++word) {
result |= storage(word);
} }
constexpr Storage last_word_mask = computeLastWordMask();
result |= (storage(kWordCount - 1) & last_word_mask);
return result == 0; return result == 0;
} }