Fix is_zero (#1147)

* Fix is_zero

* Use constexpr

* Add CUTLASS_PRAGMA_UNROLL to loops

* Avoid if branches in is_zero
This commit is contained in:
cyyever 2023-10-24 00:09:37 +08:00 committed by GitHub
parent fb10fa5308
commit 7a7796afae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -106,7 +106,7 @@ tile_traits_concept and a \ref predicate_vector_concept.
/// Statically sized array of bits implementing @concept{predicate_vector_concept}.
template <
/// Number of predicates conatined in predicate vector
/// Number of predicates contained in predicate vector
int kPredicates_,
/// Number of predicates contained in each byte of internal storage
int kPredicatesPerByte_ = 4,
@ -114,13 +114,13 @@ template <
int kPredicateStart_ = 0>
struct PredicateVector {
/// Number of bits stored by the PredicateVector
static int const kPredicates = kPredicates_;
static constexpr int kPredicates = kPredicates_;
/// Number of bits stored within each byte of the predicate bit vector
static int const kPredicatesPerByte = kPredicatesPerByte_;
static constexpr int kPredicatesPerByte = kPredicatesPerByte_;
/// First bit withing each byte containing predicates
static int const kPredicateStart = kPredicateStart_;
/// First bit within each byte containing predicates
static constexpr int kPredicateStart = kPredicateStart_;
// Make sure no one tries to put more than 8 bits in a byte :)
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
@ -132,10 +132,13 @@ struct PredicateVector {
typedef uint32_t Storage;
/// Number of bytes needed
static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
static constexpr int kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
/// Number of storage elements needed
static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
static constexpr int kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
/// The byte mask corresponding to predicates
static constexpr Storage kByteMask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
private:
//
@ -162,6 +165,27 @@ struct PredicateVector {
bit = byte_offset * 8 + bit_offset + kPredicateStart;
}
/// Returns word mask.
CUTLASS_HOST_DEVICE static constexpr bool computeWordMask() {
Storage mask(0);
CUTLASS_PRAGMA_UNROLL
for (int byte = 0; byte < sizeof(Storage); ++byte) {
mask |= (kByteMask << (byte * 8));
}
return mask;
}
/// Returns mask of last word.
CUTLASS_HOST_DEVICE static constexpr bool computeLastWordMask() {
Storage mask(0);
constexpr int count = (kBytes % sizeof(Storage) == 0) ? sizeof(Storage) : (kBytes % sizeof(Storage));
CUTLASS_PRAGMA_UNROLL
for (int byte = 0; byte < count; ++byte) {
mask |= (kByteMask << (byte * 8));
}
return mask;
}
/// Accesses a given word with optional assertions
CUTLASS_HOST_DEVICE Storage &storage(int word) {
CUTLASS_ASSERT(word < kWordCount);
@ -490,15 +514,14 @@ struct PredicateVector {
/// Returns true if entire predicate array is zero.
CUTLASS_HOST_DEVICE bool is_zero() const {
Storage mask(0);
for (int byte = 0; byte < sizeof(Storage); ++byte) {
Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
mask |= (byte_mask << (byte * 8));
}
uint32_t result = 0;
for (int word = 0; word < kWordCount; ++word) {
result |= storage(word);
constexpr Storage mask = computeWordMask();
Storage result = 0;
CUTLASS_PRAGMA_UNROLL
for (int word = 0; word < kWordCount - 1; ++word) {
result |= (storage(word) & mask);
}
constexpr Storage last_word_mask = computeLastWordMask();
result |= (storage(kWordCount - 1) & last_word_mask);
return result == 0;
}