/*************************************************************************************************** * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include #include #include #include /** Common algorithms on (hierarchical) tuples */ /** Style choice: * Forward params [using static_cast(.)] for const/non-const/ref/non-ref args * but don't bother forwarding functions as ref-qualified member fns are extremely rare */ namespace cute { // // Apply (Unpack) // (t, f) => f(t_0,t_1,...,t_n) // namespace detail { template CUTE_HOST_DEVICE constexpr auto apply(T&& t, F&& f, seq) { return f(get(static_cast(t))...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto apply(T&& t, F&& f) { return detail::apply(static_cast(t), f, tuple_seq{}); } // // Transform Apply // (t, f, g) => g(f(t_0),f(t_1),...) // namespace detail { template CUTE_HOST_DEVICE constexpr auto tapply(T&& t, F&& f, G&& g, seq) { return g(f(get(static_cast(t)))...); } template CUTE_HOST_DEVICE constexpr auto tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) { return g(f(get(static_cast(t0)), get(static_cast(t1)))...); } template CUTE_HOST_DEVICE constexpr auto tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) { return g(f(get(static_cast(t0)), get(static_cast(t1)), get(static_cast(t2)))...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto transform_apply(T&& t, F&& f, G&& g) { return detail::tapply(static_cast(t), f, g, tuple_seq{}); } template CUTE_HOST_DEVICE constexpr auto transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) { return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); } template CUTE_HOST_DEVICE constexpr auto transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) { return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); } // // For Each // (t, f) => f(t_0),f(t_1),...,f(t_n) // template CUTE_HOST_DEVICE constexpr void for_each(T&& t, F&& f) { detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); } template CUTE_HOST_DEVICE constexpr auto for_each_leaf(T&& t, F&& f) { if constexpr (is_tuple>::value) { return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); } else { return f(static_cast(t)); } CUTE_GCC_UNREACHABLE; } // // Transform // (t, f) => (f(t_0),f(t_1),...,f(t_n)) // template CUTE_HOST_DEVICE constexpr auto transform(T const& t, F&& f) { return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); } template CUTE_HOST_DEVICE constexpr auto transform(T0 const& t0, T1 const& t1, F&& f) { static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); } template CUTE_HOST_DEVICE constexpr auto transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) { static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); } template CUTE_HOST_DEVICE constexpr auto transform_leaf(T const& t, F&& f) { if constexpr (is_tuple::value) { return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); } else { return f(t); } CUTE_GCC_UNREACHABLE; } // // find and find_if // namespace detail { template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f, seq<>) { return cute::integral_constant::value>{}; } template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f, seq) { if constexpr (decltype(f(get(t)))::value) { return cute::integral_constant{}; } else { return find_if(t, f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f) { if constexpr (is_tuple::value) { return detail::find_if(t, f, tuple_seq{}); } else { return cute::integral_constant{}; } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto find(T const& t, X const& x) { return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false } template auto none_of(T const& t, F&& f) { return cute::integral_constant::value>{}; } template auto all_of(T const& t, F&& f) { auto not_f = [&](auto const& a) { return !f(a); }; return cute::integral_constant::value>{}; } template auto any_of(T const& t, F&& f) { return cute::integral_constant{}; } // // Filter // (t, f) => // template CUTE_HOST_DEVICE constexpr auto filter_tuple(T const& t, F&& f) { return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); } template CUTE_HOST_DEVICE constexpr auto filter_tuple(T0 const& t0, T1 const& t1, F&& f) { return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); } // // Fold (Reduce, Accumulate) // (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) // namespace detail { // This impl compiles much faster than cute::apply and variadic args template CUTE_HOST_DEVICE constexpr decltype(auto) fold(T&& t, V&& v, F&& f, seq<>) { return static_cast(v); } template CUTE_HOST_DEVICE constexpr decltype(auto) fold(T&& t, V&& v, F&& f, seq) { if constexpr (sizeof...(Is) == 0) { return f(static_cast(v), get(static_cast(t))); } else { return fold(static_cast(t), f(static_cast(v), get(static_cast(t))), f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto fold(T&& t, V&& v, F&& f) { if constexpr (is_tuple>::value) { return detail::fold(static_cast(t), static_cast(v), f, tuple_seq{}); } else { return f(static_cast(v), static_cast(t)); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr decltype(auto) fold_first(T&& t, F&& f) { if constexpr (is_tuple>::value) { return detail::fold(static_cast(t), get<0>(static_cast(t)), f, make_range<1,std::tuple_size>::value>{}); } else { return static_cast(t); } CUTE_GCC_UNREACHABLE; } // // front, back, take, unwrap // // Get the first non-tuple element in a hierarchical tuple template CUTE_HOST_DEVICE constexpr decltype(auto) front(T&& t) { if constexpr (is_tuple>::value) { return front(get<0>(static_cast(t))); } else { return static_cast(t); } CUTE_GCC_UNREACHABLE; } // Get the last non-tuple element in a hierarchical tuple template CUTE_HOST_DEVICE constexpr decltype(auto) back(T&& t) { if constexpr (is_tuple>::value) { constexpr int N = tuple_size>::value; return back(get(static_cast(t))); } else { return static_cast(t); } CUTE_GCC_UNREACHABLE; } // Takes the elements in the range [B,E) template CUTE_HOST_DEVICE constexpr auto take(T const& t) { return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); } // Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple template CUTE_HOST_DEVICE constexpr auto unwrap(T const& t) { if constexpr (is_tuple::value) { if constexpr (tuple_size::value == 1) { return unwrap(get<0>(t)); } else { return t; } } else { return t; } CUTE_GCC_UNREACHABLE; } // // Flatten a hierarchical tuple to a tuple of depth one. // template CUTE_HOST_DEVICE constexpr auto flatten_to_tuple(T const& t) { if constexpr (is_tuple::value) { return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); } else { return cute::make_tuple(t); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto flatten(T const& t) { if constexpr (is_tuple::value) { return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); } else { return t; } CUTE_GCC_UNREACHABLE; } // // insert and remove and replace // namespace detail { // Shortcut around tuple_cat for common insert/remove/repeat cases template CUTE_HOST_DEVICE constexpr auto construct(T const& t, X const& x, seq, seq, seq) { return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); } } // end namespace detail // Insert x into the Nth position of the tuple template CUTE_HOST_DEVICE constexpr auto insert(T const& t, X const& x) { return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); } // Remove the Nth element of the tuple template CUTE_HOST_DEVICE constexpr auto remove(T const& t) { return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); } // Replace the Nth element of the tuple with x template CUTE_HOST_DEVICE constexpr auto replace(T const& t, X const& x) { return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); } // Replace the first element of the tuple with x template CUTE_HOST_DEVICE constexpr auto replace_front(T const& t, X const& x) { if constexpr (is_tuple::value) { return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); } else { return x; } CUTE_GCC_UNREACHABLE; } // Replace the last element of the tuple with x template CUTE_HOST_DEVICE constexpr auto replace_back(T const& t, X const& x) { if constexpr (is_tuple::value) { return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); } else { return x; } CUTE_GCC_UNREACHABLE; } // // Make a tuple of Xs of tuple_size N // template CUTE_HOST_DEVICE constexpr auto repeat(X const& x) { return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); } // // Make a tuple of Xs the same profile as tuple // template CUTE_HOST_DEVICE constexpr auto repeat_like(T const& t, X const& x) { if constexpr (is_tuple::value) { return transform(t, [&](auto const& a) { return repeat_like(a,x); }); } else { return x; } CUTE_GCC_UNREACHABLE; } // Group the elements [B,E) of a T into a single element // e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) // => T<_1,_2,T<_3,_4>,_5,_6>{} template CUTE_HOST_DEVICE constexpr auto group(T const& t) { return detail::construct(t, take(t), make_seq{}, seq<0>{}, make_range::value>{}); } // // Extend a T to rank N by appending/prepending an element // template CUTE_HOST_DEVICE constexpr auto append(T const& a, X const& x) { if constexpr (is_tuple::value) { if constexpr (N == tuple_size::value) { return a; } else { static_assert(N > tuple_size::value); return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); } } else { if constexpr (N == 1) { return a; } else { return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); } } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto append(T const& a, X const& x) { if constexpr (is_tuple::value) { return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); } else { return cute::make_tuple(a, x); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto prepend(T const& a, X const& x) { if constexpr (is_tuple::value) { if constexpr (N == tuple_size::value) { return a; } else { static_assert(N > tuple_size::value); return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); } } else { if constexpr (N == 1) { return a; } else { static_assert(N > 1); return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); } } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto prepend(T const& a, X const& x) { if constexpr (is_tuple::value) { return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); } else { return cute::make_tuple(x, a); } CUTE_GCC_UNREACHABLE; } // // Inclusive scan (prefix sum) // namespace detail { template CUTE_HOST_DEVICE constexpr auto iscan(T const& t, V const& v, F&& f, seq) { // Apply the function to v and the element at I auto v_next = f(v, get(t)); // Replace I with v_next auto t_next = replace(t, v_next); #if 0 std::cout << "ISCAN i" << I << std::endl; std::cout << " t " << t << std::endl; std::cout << " i " << v << std::endl; std::cout << " f(i,t) " << v_next << std::endl; std::cout << " t_n " << t_next << std::endl; #endif if constexpr (sizeof...(Is) == 0) { return t_next; } else { return iscan(t_next, v_next, f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto iscan(T const& t, V const& v, F&& f) { return detail::iscan(t, v, f, tuple_seq{}); } // // Exclusive scan (prefix sum) // namespace detail { template CUTE_HOST_DEVICE constexpr auto escan(T const& t, V const& v, F&& f, seq) { if constexpr (sizeof...(Is) == 0) { // Replace I with v return replace(t, v); } else { // Apply the function to v and the element at I auto v_next = f(v, get(t)); // Replace I with v auto t_next = replace(t, v); #if 0 std::cout << "ESCAN i" << I << std::endl; std::cout << " t " << t << std::endl; std::cout << " i " << v << std::endl; std::cout << " f(i,t) " << v_next << std::endl; std::cout << " t_n " << t_next << std::endl; #endif // Recurse return escan(t_next, v_next, f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto escan(T const& t, V const& v, F&& f) { return detail::escan(t, v, f, tuple_seq{}); } // // Zip (Transpose) // // Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input // to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output namespace detail { template CUTE_HOST_DEVICE constexpr auto zip_(T const& t, seq) { return cute::make_tuple(get(get(t))...); } template CUTE_HOST_DEVICE constexpr auto zip(T const& t, seq, seq) { static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); return cute::make_tuple(detail::zip_(t, seq{})...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto zip(T const& t) { if constexpr (is_tuple::value) { if constexpr (is_tuple>::value) { return detail::zip(t, tuple_seq{}, tuple_seq>{}); } else { return cute::make_tuple(t); } } else { return t; } CUTE_GCC_UNREACHABLE; } // Convenient to pass them in separately template CUTE_HOST_DEVICE constexpr auto zip(T0 const& t0, T1 const& t1, Ts const&... ts) { return zip(cute::make_tuple(t0, t1, ts...)); } // // zip2_by -- A guided zip for rank-2 tuples // Take a tuple like ((A,a),((B,b),(C,c)),d) // and produce a tuple ((A,(B,C)),(a,(b,c),d)) // where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) // namespace detail { template CUTE_HOST_DEVICE constexpr auto zip2_by(T const& t, TG const& guide, seq, seq) { // zip2_by produces the modes like ((A,a),(B,b),...) auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) return cute::make_tuple(cute::make_tuple(get(split)...), cute::make_tuple(get(split)..., get(t)...)); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto zip2_by(T const& t, TG const& guide) { if constexpr (is_tuple::value) { constexpr int TR = tuple_size::value; constexpr int GR = tuple_size::value; static_assert(TR >= GR, "Mismatched ranks"); return detail::zip2_by(t, guide, make_range< 0, GR>{}, make_range{}); } else { static_assert(tuple_size::value == 2, "Mismatched ranks"); return t; } CUTE_GCC_UNREACHABLE; } } // end namespace cute