| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | /*! \file
 | 
					
						
							|  |  |  |     \brief Statically sized array of elements that accommodates subbyte trivial types | 
					
						
							|  |  |  |            in a packed storage. | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <cute/config.hpp>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  | #include <cute/numeric/numeric_types.hpp>
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #include <cute/numeric/integral_constant.hpp>
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace cute | 
					
						
							|  |  |  | { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | // Underlying subbyte storage type
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | template <class T> | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  | using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v<T> <=   8), uint8_t, | 
					
						
							|  |  |  |                                conditional_t<(cute::sizeof_bits_v<T> <=  16), uint16_t, | 
					
						
							|  |  |  |                                conditional_t<(cute::sizeof_bits_v<T> <=  32), uint32_t, | 
					
						
							|  |  |  |                                conditional_t<(cute::sizeof_bits_v<T> <=  64), uint64_t, | 
					
						
							|  |  |  |                                conditional_t<(cute::sizeof_bits_v<T> <= 128), uint128_t, | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |                                T>>>>>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | template <class T> struct subbyte_iterator; | 
					
						
							|  |  |  | template <class, class> struct swizzle_ptr; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | // subbyte_reference
 | 
					
						
							|  |  |  | //   Proxy object for sub-byte element references
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | template <class T> | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | struct subbyte_reference | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   // Iterator Element type (const or non-const)
 | 
					
						
							|  |  |  |   using element_type = T; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Iterator Value type without type qualifier.
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   using value_type   = remove_cv_t<T>; | 
					
						
							|  |  |  |   // Storage type (const or non-const)
 | 
					
						
							|  |  |  |   using storage_type = conditional_t<(is_const_v<T>), subbyte_storage_type_t<T> const, subbyte_storage_type_t<T>>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   static_assert(sizeof_bits_v<storage_type> % 8 == 0, "Storage type is not supported"); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   static_assert(sizeof_bits_v<element_type> <= sizeof_bits_v<storage_type>, | 
					
						
							|  |  |  |                 "Size of Element must not be greater than Storage."); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | private: | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Bitmask for covering one item
 | 
					
						
							|  |  |  |   static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v<storage_type> - sizeof_bits_v<element_type>)); | 
					
						
							|  |  |  |   // Flag for fast branching on straddled elements
 | 
					
						
							|  |  |  |   static constexpr bool is_storage_unaligned = ((sizeof_bits_v<storage_type> % sizeof_bits_v<element_type>) != 0); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-30 04:21:31 +08:00
										 |  |  |   friend struct subbyte_iterator<T>; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Pointer to storage element
 | 
					
						
							|  |  |  |   storage_type* ptr_ = nullptr; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Bit index of value_type starting position within storage_type element.
 | 
					
						
							|  |  |  |   // RI: 0 <= idx_ < sizeof_bit<storage_type>
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   uint8_t idx_ = 0; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Ctor
 | 
					
						
							|  |  |  |   template <class PointerType> | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast<storage_type*>(ptr)), idx_(idx) {} | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Copy Ctor
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   subbyte_reference(subbyte_reference const& other) { | 
					
						
							|  |  |  |     *this = element_type(other); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Copy Assignment
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   subbyte_reference& operator=(subbyte_reference const& other) { | 
					
						
							|  |  |  |     return *this = element_type(other); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Assignment
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   template <class T_ = element_type> | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   enable_if_t<!is_const_v<T_>, subbyte_reference&> operator=(element_type x) | 
					
						
							|  |  |  |   { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     static_assert(is_same_v<T_, element_type>, "Do not specify template arguments!"); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     storage_type item = (reinterpret_cast<storage_type const&>(x) & BitMask); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // Update the current storage element
 | 
					
						
							|  |  |  |     storage_type bit_mask_0 = storage_type(BitMask << idx_); | 
					
						
							|  |  |  |     ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic)
 | 
					
						
							|  |  |  |     if (is_storage_unaligned && idx_ + sizeof_bits_v<value_type> > sizeof_bits_v<storage_type>) { | 
					
						
							|  |  |  |       uint8_t straddle_bits = uint8_t(sizeof_bits_v<storage_type> - idx_); | 
					
						
							|  |  |  |       storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); | 
					
						
							|  |  |  |       // Update the next storage element
 | 
					
						
							|  |  |  |       ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return *this; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Comparison of referenced values
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() <  y.get(); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() >  y.get(); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Value
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   element_type get() const | 
					
						
							|  |  |  |   { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     if constexpr (is_same_v<bool, value_type>) {      // Extract to bool -- potentially faster impl
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |       return bool((*ptr_) & (BitMask << idx_)); | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     } else {                                          // Extract to element_type
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |       // Extract from the current storage element
 | 
					
						
							|  |  |  |       auto item = storage_type((ptr_[0] >> idx_) & BitMask); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic)
 | 
					
						
							|  |  |  |       if (is_storage_unaligned && idx_ + sizeof_bits_v<value_type> > sizeof_bits_v<storage_type>) { | 
					
						
							|  |  |  |         uint8_t straddle_bits = uint8_t(sizeof_bits_v<storage_type> - idx_); | 
					
						
							|  |  |  |         storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); | 
					
						
							|  |  |  |         // Extract from the next storage element
 | 
					
						
							|  |  |  |         item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); | 
					
						
							|  |  |  |       } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       return reinterpret_cast<element_type&>(item); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Extract to type element_type
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   operator element_type() const { | 
					
						
							|  |  |  |     return get(); | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // Address
 | 
					
						
							|  |  |  |   subbyte_iterator<T> operator&() const { | 
					
						
							|  |  |  |     return {ptr_, idx_}; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-10 03:33:27 +08:00
										 |  |  | template <class T> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE | 
					
						
							|  |  |  | void | 
					
						
							|  |  |  | print(subbyte_reference<T> ref) { | 
					
						
							|  |  |  |   cute::print(ref.get()); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class T> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE | 
					
						
							|  |  |  | void | 
					
						
							|  |  |  | pretty_print(subbyte_reference<T> ref) { | 
					
						
							|  |  |  |   cute::pretty_print(ref.get()); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | // subbyte_iterator
 | 
					
						
							|  |  |  | //   Random-access iterator over subbyte references
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | template <class T> | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | struct subbyte_iterator | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   // Iterator Element type (const or non-const)
 | 
					
						
							|  |  |  |   using element_type = T; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Iterator Value type without type qualifier.
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   using value_type   = remove_cv_t<T>; | 
					
						
							|  |  |  |   // Storage type (const or non-const)
 | 
					
						
							|  |  |  |   using storage_type = conditional_t<(is_const_v<T>), subbyte_storage_type_t<T> const, subbyte_storage_type_t<T>>; | 
					
						
							|  |  |  |   // Reference proxy type
 | 
					
						
							|  |  |  |   using reference = subbyte_reference<element_type>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   static_assert(sizeof_bits_v<storage_type> % 8 == 0, "Storage type is not supported"); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   static_assert(sizeof_bits_v<element_type> <= sizeof_bits_v<storage_type>, | 
					
						
							|  |  |  |                 "Size of Element must not be greater than Storage."); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | private: | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-12-30 04:21:31 +08:00
										 |  |  |   template <class, class> friend struct swizzle_ptr; | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   template <class U> friend CUTE_HOST_DEVICE constexpr U* raw_pointer_cast(subbyte_iterator<U> const&); | 
					
						
							|  |  |  |   template <class N, class U> friend CUTE_HOST_DEVICE constexpr auto recast_ptr(subbyte_iterator<U> const&); | 
					
						
							|  |  |  |   template <class U> friend CUTE_HOST_DEVICE void print(subbyte_iterator<U> const&); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Pointer to storage element
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   storage_type* ptr_; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Bit index of value_type starting position within storage_type element.
 | 
					
						
							|  |  |  |   // RI: 0 <= idx_ < sizeof_bit<storage_type>
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   uint8_t idx_; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | public: | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   // Default Ctor
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   subbyte_iterator() : ptr_{nullptr}, idx_{0} {}; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   // Ctor
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   template <class PointerType> | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast<storage_type*>(ptr)), idx_(idx) { } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   reference operator*() const { | 
					
						
							|  |  |  |     return reference(ptr_, idx_); | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator& operator+=(uint64_t k) { | 
					
						
							|  |  |  |     k = sizeof_bits_v<value_type> * k + idx_; | 
					
						
							|  |  |  |     ptr_ += k / sizeof_bits_v<storage_type>; | 
					
						
							|  |  |  |     idx_  = k % sizeof_bits_v<storage_type>; | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return *this; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator operator+(uint64_t k) const { | 
					
						
							|  |  |  |     return subbyte_iterator(ptr_, idx_) += k; | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   reference operator[](uint64_t k) const { | 
					
						
							|  |  |  |     return *(*this + k); | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator& operator++() { | 
					
						
							|  |  |  |     idx_ += sizeof_bits_v<value_type>; | 
					
						
							|  |  |  |     if (idx_ >= sizeof_bits_v<storage_type>) { | 
					
						
							|  |  |  |       ++ptr_; | 
					
						
							|  |  |  |       idx_ -= sizeof_bits_v<storage_type>; | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return *this; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator operator++(int) { | 
					
						
							|  |  |  |     subbyte_iterator ret(*this); | 
					
						
							|  |  |  |     ++(*this); | 
					
						
							|  |  |  |     return ret; | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator& operator--() { | 
					
						
							|  |  |  |     if (idx_ >= sizeof_bits_v<value_type>) { | 
					
						
							|  |  |  |       idx_ -= sizeof_bits_v<value_type>; | 
					
						
							|  |  |  |     } else { | 
					
						
							|  |  |  |       --ptr_; | 
					
						
							|  |  |  |       idx_ += sizeof_bits_v<storage_type> - sizeof_bits_v<value_type>; | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  |     return *this; | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   subbyte_iterator operator--(int) { | 
					
						
							|  |  |  |     subbyte_iterator ret(*this); | 
					
						
							|  |  |  |     --(*this); | 
					
						
							|  |  |  |     return ret; | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; | 
					
						
							|  |  |  |   } | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { | 
					
						
							|  |  |  |     return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y <  x); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return  (y <  x); } | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr friend | 
					
						
							|  |  |  |   bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x <  y); } | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  | }; | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  | // Conversion to raw pointer with loss of subbyte index
 | 
					
						
							|  |  |  | template <class T> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | T* | 
					
						
							|  |  |  | raw_pointer_cast(subbyte_iterator<T> const& x) { | 
					
						
							|  |  |  |   assert(x.idx_ == 0); | 
					
						
							|  |  |  |   return reinterpret_cast<T*>(x.ptr_); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  | // Conversion to NewT_ with possible loss of subbyte index
 | 
					
						
							|  |  |  | template <class NewT_, class T> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | auto | 
					
						
							|  |  |  | recast_ptr(subbyte_iterator<T> const& x) { | 
					
						
							|  |  |  |   using NewT = conditional_t<(is_const_v<T>), NewT_ const, NewT_>; | 
					
						
							|  |  |  |   if constexpr (cute::is_subbyte_v<NewT>) {       // Making subbyte_iter, preserve the subbyte idx
 | 
					
						
							|  |  |  |     return subbyte_iterator<NewT>(x.ptr_, x.idx_); | 
					
						
							|  |  |  |   } else {                                       // Not subbyte, assume/assert subbyte idx 0
 | 
					
						
							|  |  |  |     return reinterpret_cast<NewT*>(raw_pointer_cast(x)); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  |   CUTE_GCC_UNREACHABLE; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  | template <class T> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE void | 
					
						
							|  |  |  | print(subbyte_iterator<T> const& x) { | 
					
						
							|  |  |  |   printf("subptr[%db](%p.%u)", int(sizeof_bits_v<T>), x.ptr_, x.idx_); | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | // array_subbyte
 | 
					
						
							|  |  |  | //   Statically sized array for non-byte-aligned data types
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | template <class T, size_t N> | 
					
						
							|  |  |  | struct array_subbyte | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using element_type    = T; | 
					
						
							|  |  |  |   using value_type      = remove_cv_t<T>; | 
					
						
							|  |  |  |   using pointer         = element_type*; | 
					
						
							|  |  |  |   using const_pointer   = element_type const*; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   using size_type       = size_t; | 
					
						
							|  |  |  |   using difference_type = ptrdiff_t; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   //
 | 
					
						
							|  |  |  |   // References
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   using reference       = subbyte_reference<element_type>; | 
					
						
							|  |  |  |   using const_reference = subbyte_reference<element_type const>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   //
 | 
					
						
							|  |  |  |   // Iterators
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   using iterator        = subbyte_iterator<element_type>; | 
					
						
							|  |  |  |   using const_iterator  = subbyte_iterator<element_type const>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Storage type (const or non-const)
 | 
					
						
							|  |  |  |   using storage_type = conditional_t<(is_const_v<T>), subbyte_storage_type_t<T> const, subbyte_storage_type_t<T>>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   static_assert(sizeof_bits_v<storage_type> % 8 == 0, "Storage type is not supported"); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | private: | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |   // Number of storage elements, ceil_div
 | 
					
						
							|  |  |  |   static constexpr size_type StorageElements = (N * sizeof_bits_v<value_type> + sizeof_bits_v<storage_type> - 1) / sizeof_bits_v<storage_type>; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Internal storage
 | 
					
						
							|  |  |  |   storage_type storage[StorageElements]; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | public: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   size_type size() const { | 
					
						
							|  |  |  |     return N; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   size_type max_size() const { | 
					
						
							|  |  |  |     return N; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   bool empty() const { | 
					
						
							|  |  |  |     return !N; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |   // Efficient clear method
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   void clear() { | 
					
						
							|  |  |  |     CUTE_UNROLL | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     for (size_type i = 0; i < StorageElements; ++i) { | 
					
						
							|  |  |  |       storage[i] = storage_type(0); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   void fill(T const& value) { | 
					
						
							|  |  |  |     CUTE_UNROLL | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     for (size_type i = 0; i < N; ++i) { | 
					
						
							|  |  |  |       at(i) = value; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     } | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   reference at(size_type pos) { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return iterator(storage)[pos]; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_reference at(size_type pos) const { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return const_iterator(storage)[pos]; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   reference operator[](size_type pos) { | 
					
						
							|  |  |  |     return at(pos); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_reference operator[](size_type pos) const { | 
					
						
							|  |  |  |     return at(pos); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   reference front() { | 
					
						
							|  |  |  |     return at(0); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_reference front() const { | 
					
						
							|  |  |  |     return at(0); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   reference back() { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return at(N-1); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_reference back() const { | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     return at(N-1); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   // In analogy to std::vector<bool>::data(), these functions are deleted to prevent bugs.
 | 
					
						
							|  |  |  |   // Instead, prefer
 | 
					
						
							|  |  |  |   //   auto* data = raw_pointer_cast(my_subbyte_array.begin());
 | 
					
						
							|  |  |  |   // where the type of auto* is implementation-defined and
 | 
					
						
							|  |  |  |   // with the knowledge that [data, data + my_subbyte_array.size()) may not be a valid range.
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   pointer data() = delete; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							| 
									
										
										
										
											2024-07-29 20:46:24 +08:00
										 |  |  |   const_pointer data() const = delete; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   iterator begin() { | 
					
						
							|  |  |  |     return iterator(storage); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_iterator begin() const { | 
					
						
							|  |  |  |     return const_iterator(storage); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_iterator cbegin() const { | 
					
						
							|  |  |  |     return begin(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   iterator end() { | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     return iterator(storage) + N; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_iterator end() const { | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     return const_iterator(storage) + N; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  |   const_iterator cend() const { | 
					
						
							|  |  |  |     return end(); | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  |   // Comparison operators
 | 
					
						
							|  |  |  |   //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Operators
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | void clear(array_subbyte<T,N>& a) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   a.clear(); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | void fill(array_subbyte<T,N>& a, T const& value) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   a.fill(value); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace cute
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Specialize tuple-related functionality for cute::array_subbyte
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #if defined(__CUDACC_RTC__)
 | 
					
						
							|  |  |  | #include <cuda/std/tuple>
 | 
					
						
							|  |  |  | #else
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | #include <tuple>
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace cute | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <size_t I, class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | T& get(array_subbyte<T,N>& a) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   static_assert(I < N, "Index out of range"); | 
					
						
							|  |  |  |   return a[I]; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <size_t I, class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | T const& get(array_subbyte<T,N> const& a) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   static_assert(I < N, "Index out of range"); | 
					
						
							|  |  |  |   return a[I]; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <size_t I, class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | T&& get(array_subbyte<T,N>&& a) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   static_assert(I < N, "Index out of range"); | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  |   return cute::move(a[I]); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // end namespace cute
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | namespace CUTE_STL_NAMESPACE | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | template <class T> | 
					
						
							|  |  |  | struct is_reference<cute::subbyte_reference<T>> | 
					
						
							|  |  |  |     : CUTE_STL_NAMESPACE::true_type | 
					
						
							|  |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <class T, size_t N> | 
					
						
							|  |  |  | struct tuple_size<cute::array_subbyte<T,N>> | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <size_t I, class T, size_t N> | 
					
						
							|  |  |  | struct tuple_element<I, cute::array_subbyte<T,N>> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using type = T; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class T, size_t N> | 
					
						
							|  |  |  | struct tuple_size<const cute::array_subbyte<T,N>> | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <size_t I, class T, size_t N> | 
					
						
							|  |  |  | struct tuple_element<I, const cute::array_subbyte<T,N>> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using type = T; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // end namespace CUTE_STL_NAMESPACE
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | namespace std | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #if defined(__CUDACC_RTC__)
 | 
					
						
							|  |  |  | template <class... _Tp> | 
					
						
							|  |  |  | struct tuple_size; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  | template <size_t _Ip, class... _Tp> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | struct tuple_element; | 
					
						
							|  |  |  | #endif
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | struct tuple_size<cute::array_subbyte<T,N>> | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <size_t I, class T, size_t N> | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | struct tuple_element<I, cute::array_subbyte<T,N>> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using type = T; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | template <class T, size_t N> | 
					
						
							|  |  |  | struct tuple_size<const cute::array_subbyte<T,N>> | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |     : CUTE_STL_NAMESPACE::integral_constant<size_t, N> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <size_t I, class T, size_t N> | 
					
						
							|  |  |  | struct tuple_element<I, const cute::array_subbyte<T,N>> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using type = T; | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | } // end namespace std
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD
 |