/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include // CUTE_HOST_DEVICE #include // cute::Tensor #include // cute::tuple namespace cute { // A tuple of Iterators that can be offset asymmetrically // Note that this only accepts op+(tuple) and op[tuple] // where each iterator will be offset by its respective index only. // READ-ONLY for now until cute::tuple can be constructed with references. template struct ZipIterator { using value_type = cute::tuple...>; using element_type = cute::tuple...>; // NOTE: cute::tuple does not support constructions with references at the moment. // Consider fixes and/or an implementation of std::forward_as_tuple. // For now, use a cute::tuple of value_types instead, which makes this Iterator READ-ONLY. //using reference = cute::tuple...>; using reference = value_type; ZipIterator() = delete; CUTE_HOST_DEVICE constexpr ZipIterator(Iters... iters) : iters_(iters...) {} CUTE_HOST_DEVICE constexpr ZipIterator(cute::tuple const& iters) : iters_(iters) {} CUTE_HOST_DEVICE constexpr reference operator*() const { return cute::apply(iters_, [](auto&&... args) { return reference(*args...); }); } template CUTE_HOST_DEVICE constexpr ZipIterator operator+(cute::tuple const& idxs) const { static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators."); return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }); } template CUTE_HOST_DEVICE constexpr reference operator[](cute::tuple const& idxs) const { return *(*this + idxs); } cute::tuple iters_; }; //------------------------------------------------------------------------------ // type traits template struct is_rmem> : conjunction...> {}; template struct is_smem> : conjunction...> {}; template struct is_gmem> : conjunction...> {}; // A tuple of Layouts that operates on each Layout symmetrically // The Layouts need to have compatible shapes and ranks. // The ZipLayout presents the intersection of the domain of its component Layouts. // E.g. all Layouts accept 1D coords and ZipLayout does as well. // The ZipLayout returns the union of the codomain of its component Layouts. // E.g. all Layouts return an integer so ZipLayout returns a tuple of integers. template struct ZipLayout { static constexpr int rank = (int(0) | ... | Layouts::rank); static_assert((is_layout::value && ...), "All template parameters must be layouts"); static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank"); CUTE_HOST_DEVICE constexpr ZipLayout(Layouts const&... layouts) : layouts_(layouts...) {} CUTE_HOST_DEVICE constexpr ZipLayout(cute::tuple const& layouts) : layouts_(layouts) {} template CUTE_HOST_DEVICE constexpr auto operator()(Coord const& coord) const { if constexpr (has_underscore::value) { return ZipLayout(cute::transform(layouts_, [&] (auto layout) { return layout(coord); })); } else { return cute::transform(layouts_, [&] (auto layout) { return layout(coord); }); } CUTE_GCC_UNREACHABLE; } // op() convenience function for multi-dimensional coordinates template CUTE_HOST_DEVICE constexpr decltype(auto) operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { return operator()(make_coord(c0,c1,cs...)); } cute::tuple layouts_; }; template struct is_layout> : true_type {}; // // make_zip_tensor and unzip_tensor // template CUTE_HOST_DEVICE constexpr auto make_zip_tensor(Tensor const&... tensors) { return make_tensor(ZipIterator(tensors.data()...), ZipLayout(tensors.layout()...)); } template CUTE_HOST_DEVICE constexpr auto unzip_tensor(Tensor const& tensor) { return cute::transform(tensor.data().iters_, tensor.layout().layouts_, [](auto iter, auto layout) { return make_tensor(iter, layout); }); } // // Utilities // template CUTE_HOST_DEVICE constexpr auto rank(ZipLayout const& layouts) { return rank(get<0>(layouts.layouts_)); } template CUTE_HOST_DEVICE constexpr auto size(ZipLayout const& layouts) { return size(get<0>(layouts.layouts_)); } // // Manipulation // // Extend each component layout to rank-N by appending Layout @a x. template CUTE_HOST_DEVICE constexpr auto append(ZipLayout const& layouts, Layout const& x = {}) { return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return append(t, x); })); } // Extend each component layout to rank-N by prepending Layout @a x. template CUTE_HOST_DEVICE constexpr auto prepend(ZipLayout const& layouts, Layout const& x = {}) { return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return prepend(t, x); })); } template CUTE_HOST_DEVICE constexpr auto logical_divide(ZipLayout const& layouts, Tiler const& tiler) { return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return logical_divide(t, tiler); })); } template CUTE_HOST_DEVICE constexpr auto zipped_divide(ZipLayout const& layouts, Tiler const& tiler) { return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return zipped_divide(t, tiler); })); } // Return by calling slice_and_offset and all component layouts. template CUTE_HOST_DEVICE constexpr auto slice_and_offset(Coord const& c, ZipLayout const& layouts) { auto result = cute::zip(cute::transform(layouts.layouts_, [&c](auto const& layout) { return slice_and_offset(c, layout); })); return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result)); } } // end namespace cute