Fix is_zero (#1147)
* Fix is_zero * Use constexpr * Add CUTLASS_PRAGMA_UNROLL to loops * Avoid if branches in is_zero
This commit is contained in:
parent
fb10fa5308
commit
7a7796afae
@ -106,7 +106,7 @@ tile_traits_concept and a \ref predicate_vector_concept.
|
||||
|
||||
/// Statically sized array of bits implementing @concept{predicate_vector_concept}.
|
||||
template <
|
||||
/// Number of predicates conatined in predicate vector
|
||||
/// Number of predicates contained in predicate vector
|
||||
int kPredicates_,
|
||||
/// Number of predicates contained in each byte of internal storage
|
||||
int kPredicatesPerByte_ = 4,
|
||||
@ -114,13 +114,13 @@ template <
|
||||
int kPredicateStart_ = 0>
|
||||
struct PredicateVector {
|
||||
/// Number of bits stored by the PredicateVector
|
||||
static int const kPredicates = kPredicates_;
|
||||
static constexpr int kPredicates = kPredicates_;
|
||||
|
||||
/// Number of bits stored within each byte of the predicate bit vector
|
||||
static int const kPredicatesPerByte = kPredicatesPerByte_;
|
||||
static constexpr int kPredicatesPerByte = kPredicatesPerByte_;
|
||||
|
||||
/// First bit withing each byte containing predicates
|
||||
static int const kPredicateStart = kPredicateStart_;
|
||||
/// First bit within each byte containing predicates
|
||||
static constexpr int kPredicateStart = kPredicateStart_;
|
||||
|
||||
// Make sure no one tries to put more than 8 bits in a byte :)
|
||||
static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte");
|
||||
@ -132,10 +132,13 @@ struct PredicateVector {
|
||||
typedef uint32_t Storage;
|
||||
|
||||
/// Number of bytes needed
|
||||
static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
||||
static constexpr int kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;
|
||||
|
||||
/// Number of storage elements needed
|
||||
static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
|
||||
static constexpr int kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage));
|
||||
|
||||
/// The byte mask corresponding to predicates
|
||||
static constexpr Storage kByteMask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
|
||||
|
||||
private:
|
||||
//
|
||||
@ -162,6 +165,27 @@ struct PredicateVector {
|
||||
bit = byte_offset * 8 + bit_offset + kPredicateStart;
|
||||
}
|
||||
|
||||
/// Returns word mask.
|
||||
CUTLASS_HOST_DEVICE static constexpr bool computeWordMask() {
|
||||
Storage mask(0);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int byte = 0; byte < sizeof(Storage); ++byte) {
|
||||
mask |= (kByteMask << (byte * 8));
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
/// Returns mask of last word.
|
||||
CUTLASS_HOST_DEVICE static constexpr bool computeLastWordMask() {
|
||||
Storage mask(0);
|
||||
constexpr int count = (kBytes % sizeof(Storage) == 0) ? sizeof(Storage) : (kBytes % sizeof(Storage));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int byte = 0; byte < count; ++byte) {
|
||||
mask |= (kByteMask << (byte * 8));
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
/// Accesses a given word with optional assertions
|
||||
CUTLASS_HOST_DEVICE Storage &storage(int word) {
|
||||
CUTLASS_ASSERT(word < kWordCount);
|
||||
@ -490,15 +514,14 @@ struct PredicateVector {
|
||||
|
||||
/// Returns true if entire predicate array is zero.
|
||||
CUTLASS_HOST_DEVICE bool is_zero() const {
|
||||
Storage mask(0);
|
||||
for (int byte = 0; byte < sizeof(Storage); ++byte) {
|
||||
Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
|
||||
mask |= (byte_mask << (byte * 8));
|
||||
}
|
||||
uint32_t result = 0;
|
||||
for (int word = 0; word < kWordCount; ++word) {
|
||||
result |= storage(word);
|
||||
constexpr Storage mask = computeWordMask();
|
||||
Storage result = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int word = 0; word < kWordCount - 1; ++word) {
|
||||
result |= (storage(word) & mask);
|
||||
}
|
||||
constexpr Storage last_word_mask = computeLastWordMask();
|
||||
result |= (storage(kWordCount - 1) & last_word_mask);
|
||||
return result == 0;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user