| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							| 
									
										
										
										
											2024-01-17 03:37:22 +08:00
										 |  |  |  * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-10-10 03:33:27 +08:00
										 |  |  | #include <cute/config.hpp>                     // CUTE_INLINE_CONSTANT, CUTE_HOST_DEVICE
 | 
					
						
							|  |  |  | #include <cute/container/tuple.hpp>            // cute::is_tuple
 | 
					
						
							|  |  |  | #include <cute/numeric/integral_constant.hpp>  // cute::false_type, cute::true_type
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace cute | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // For slicing
 | 
					
						
							|  |  |  | struct Underscore : Int<0> {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | CUTE_INLINE_CONSTANT Underscore _; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-20 05:51:04 +08:00
										 |  |  | // Convenient alias
 | 
					
						
							|  |  |  | using X = Underscore; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | // Treat Underscore as an integral like integral_constant
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct is_integral<Underscore> : true_type {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class T> | 
					
						
							|  |  |  | struct is_underscore : false_type {}; | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct is_underscore<Underscore> : true_type {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Tuple trait for detecting static member element
 | 
					
						
							|  |  |  | template <class Tuple, class Elem, class Enable = void> | 
					
						
							|  |  |  | struct has_elem : false_type {}; | 
					
						
							|  |  |  | template <class Elem> | 
					
						
							|  |  |  | struct has_elem<Elem, Elem> : true_type {}; | 
					
						
							|  |  |  | template <class Tuple, class Elem> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | struct has_elem<Tuple, Elem, enable_if_t<is_tuple<Tuple>::value> > | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     : has_elem<Tuple, Elem, tuple_seq<Tuple> > {}; | 
					
						
							|  |  |  | template <class Tuple, class Elem, int... Is> | 
					
						
							|  |  |  | struct has_elem<Tuple, Elem, seq<Is...>> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     : disjunction<has_elem<tuple_element_t<Is, Tuple>, Elem>...> {}; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | // Tuple trait for detecting static member element
 | 
					
						
							|  |  |  | template <class Tuple, class Elem, class Enable = void> | 
					
						
							|  |  |  | struct all_elem : false_type {}; | 
					
						
							|  |  |  | template <class Elem> | 
					
						
							|  |  |  | struct all_elem<Elem, Elem> : true_type {}; | 
					
						
							|  |  |  | template <class Tuple, class Elem> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | struct all_elem<Tuple, Elem, enable_if_t<is_tuple<Tuple>::value> > | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |     : all_elem<Tuple, Elem, tuple_seq<Tuple> > {}; | 
					
						
							|  |  |  | template <class Tuple, class Elem, int... Is> | 
					
						
							|  |  |  | struct all_elem<Tuple, Elem, seq<Is...>> | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     : conjunction<all_elem<tuple_element_t<Is, Tuple>, Elem>...> {}; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | // Tuple trait for detecting Underscore member
 | 
					
						
							|  |  |  | template <class Tuple> | 
					
						
							|  |  |  | using has_underscore = has_elem<Tuple, Underscore>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class Tuple> | 
					
						
							|  |  |  | using all_underscore = all_elem<Tuple, Underscore>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class Tuple> | 
					
						
							|  |  |  | using has_int1 = has_elem<Tuple, Int<1>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <class Tuple> | 
					
						
							|  |  |  | using has_int0 = has_elem<Tuple, Int<0>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // Slice keeps only the elements of Tuple B that are paired with an Underscore
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | namespace detail { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | template <class A, class B> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | auto | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | lift_slice(A const& a, B const& b) | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   if constexpr (is_tuple<A>::value) { | 
					
						
							|  |  |  |     static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks"); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_slice(x,y); }); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } else if constexpr (is_underscore<A>::value) { | 
					
						
							|  |  |  |     return cute::tuple<B>{b}; | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     return cute::tuple<>{}; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_GCC_UNREACHABLE; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | } // end namespace detail
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Entry point overrides the lifting so that slice(_,b) == b
 | 
					
						
							|  |  |  | template <class A, class B> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | auto | 
					
						
							|  |  |  | slice(A const& a, B const& b) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   if constexpr (is_tuple<A>::value) { | 
					
						
							|  |  |  |     static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks"); | 
					
						
							|  |  |  |     return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_slice(x,y); }); | 
					
						
							|  |  |  |   } else if constexpr (is_underscore<A>::value) { | 
					
						
							|  |  |  |     return b; | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     return cute::tuple<>{}; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_GCC_UNREACHABLE; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | //
 | 
					
						
// Dice keeps only the elements of Tuple B that are paired with a non-Underscore (e.g. an Int)
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | namespace detail { | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | template <class A, class B> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | auto | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | lift_dice(A const& a, B const& b) | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | { | 
					
						
							|  |  |  |   if constexpr (is_tuple<A>::value) { | 
					
						
							|  |  |  |     static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks"); | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  |     return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_dice(x,y); }); | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   } else if constexpr (is_underscore<A>::value) { | 
					
						
							|  |  |  |     return cute::tuple<>{}; | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     return cute::tuple<B>{b}; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_GCC_UNREACHABLE; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-02 23:09:05 +08:00
										 |  |  | } // end namespace detail
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Entry point overrides the lifting so that dice(1,b) == b
 | 
					
						
							|  |  |  | template <class A, class B> | 
					
						
							|  |  |  | CUTE_HOST_DEVICE constexpr | 
					
						
							|  |  |  | auto | 
					
						
							|  |  |  | dice(A const& a, B const& b) | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   if constexpr (is_tuple<A>::value) { | 
					
						
							|  |  |  |     static_assert(tuple_size<A>::value == tuple_size<B>::value, "Mismatched Ranks"); | 
					
						
							|  |  |  |     return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_dice(x,y); }); | 
					
						
							|  |  |  |   } else if constexpr (is_underscore<A>::value) { | 
					
						
							|  |  |  |     return cute::tuple<>{}; | 
					
						
							|  |  |  |   } else { | 
					
						
							|  |  |  |     return b; | 
					
						
							|  |  |  |   } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   CUTE_GCC_UNREACHABLE; | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | //
 | 
					
						
							|  |  |  | // Display utilities
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | CUTE_HOST_DEVICE void print(Underscore const&) { | 
					
						
							|  |  |  |   printf("_"); | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #if !defined(__CUDACC_RTC__)
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { | 
					
						
							|  |  |  |   return os << "_"; | 
					
						
							|  |  |  | } | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #endif
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | } // end namespace cute
 |