From 7a7796afaeb1986bb50b44ab35fefaca75394528 Mon Sep 17 00:00:00 2001 From: cyyever Date: Tue, 24 Oct 2023 00:09:37 +0800 Subject: [PATCH] Fix is_zero (#1147) * Fix is_zero * Use constexpr * Add CUTLASS_PRAGMA_UNROLL to loops * Avoid if branches in is_zero --- include/cutlass/predicate_vector.h | 53 +++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 15 deletions(-) diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index d1582257..b5ffe74e 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -106,7 +106,7 @@ tile_traits_concept and a \ref predicate_vector_concept. /// Statically sized array of bits implementing @concept{predicate_vector_concept}. template < - /// Number of predicates conatined in predicate vector + /// Number of predicates contained in predicate vector int kPredicates_, /// Number of predicates contained in each byte of internal storage int kPredicatesPerByte_ = 4, @@ -114,13 +114,13 @@ template < int kPredicateStart_ = 0> struct PredicateVector { /// Number of bits stored by the PredicateVector - static int const kPredicates = kPredicates_; + static constexpr int kPredicates = kPredicates_; /// Number of bits stored within each byte of the predicate bit vector - static int const kPredicatesPerByte = kPredicatesPerByte_; + static constexpr int kPredicatesPerByte = kPredicatesPerByte_; - /// First bit withing each byte containing predicates - static int const kPredicateStart = kPredicateStart_; + /// First bit within each byte containing predicates + static constexpr int kPredicateStart = kPredicateStart_; // Make sure no one tries to put more than 8 bits in a byte :) static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte"); @@ -132,10 +132,13 @@ struct PredicateVector { typedef uint32_t Storage; /// Number of bytes needed - static int const kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; + static constexpr int kBytes = (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte; /// Number of storage elements needed - static int const kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); + static constexpr int kWordCount = (kBytes + int(sizeof(Storage)) - 1) / int(sizeof(Storage)); + + /// The byte mask corresponding to predicates + static constexpr Storage kByteMask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart); private: // @@ -162,6 +165,27 @@ struct PredicateVector { bit = byte_offset * 8 + bit_offset + kPredicateStart; } + /// Returns word mask. + CUTLASS_HOST_DEVICE static constexpr bool computeWordMask() { + Storage mask(0); + CUTLASS_PRAGMA_UNROLL + for (int byte = 0; byte < sizeof(Storage); ++byte) { + mask |= (kByteMask << (byte * 8)); + } + return mask; + } + + /// Returns mask of last word. + CUTLASS_HOST_DEVICE static constexpr bool computeLastWordMask() { + Storage mask(0); + constexpr int count = (kBytes % sizeof(Storage) == 0) ? sizeof(Storage) : (kBytes % sizeof(Storage)); + CUTLASS_PRAGMA_UNROLL + for (int byte = 0; byte < count; ++byte) { + mask |= (kByteMask << (byte * 8)); + } + return mask; + } + /// Accesses a given word with optional assertions CUTLASS_HOST_DEVICE Storage &storage(int word) { CUTLASS_ASSERT(word < kWordCount); @@ -490,15 +514,14 @@ struct PredicateVector { /// Returns true if entire predicate array is zero. CUTLASS_HOST_DEVICE bool is_zero() const { - Storage mask(0); - for (int byte = 0; byte < sizeof(Storage); ++byte) { - Storage byte_mask = (((1 << kPredicatesPerByte) - 1) << kPredicateStart); - mask |= (byte_mask << (byte * 8)); - } - uint32_t result = 0; - for (int word = 0; word < kWordCount; ++word) { - result |= storage(word); + constexpr Storage mask = computeWordMask(); + Storage result = 0; + CUTLASS_PRAGMA_UNROLL + for (int word = 0; word < kWordCount - 1; ++word) { + result |= (storage(word) & mask); } + constexpr Storage last_word_mask = computeLastWordMask(); + result |= (storage(kWordCount - 1) & last_word_mask); return result == 0; }