/*************************************************************************************************** * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include /* This implements a ComposedLayout of the form * InvolutionFn o OffsetPlus o Layout * where the InvolutionFn need not be linear (hence the need for the Offset). * * This ComposedLayout provides similar coordinate-to-index mapping and layout manipulations, * but is not considered a "normal" layout. * For example, this layout provides size() functions, but does not provide stride() functions. * * Furthermore, for known InvolutionFns, this layout attempts to decay itself * to a normal-layout with dynamic or static strides. * This is possible by determining the subdomain of the Involution function * that is identity and testing if the right Layout's codomain is contained * within it. */ namespace cute { // A Layout of non-trivially composable functions: F o I o L template struct ComposedLayout : private cute::tuple // EBO for static layouts { CUTE_HOST_DEVICE constexpr ComposedLayout(InvolutionFn const& fn = {}, IntermediateOffset const& offset = {}, Layout const& layout = {}) : cute::tuple(fn, offset, layout) {} // // Accessors // static constexpr int rank = Layout::rank; CUTE_HOST_DEVICE constexpr decltype(auto) swizzle_fn() const { return get<0>(static_cast const&>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) offset_fn() const { return get<1>(static_cast const&>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) layout_fn() const { return get<2>(static_cast const&>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) layout() const { return *this; } CUTE_HOST_DEVICE constexpr decltype(auto) shape() const { return layout_fn().shape(); } // Doesn't really make sense to ask for the strides of this "layout" CUTE_HOST_DEVICE constexpr decltype(auto) stride() const = delete; // // Mappings // // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) // OR // Slice the layout and return the sublayout (Coord has an Underscore slice op) template CUTE_HOST_DEVICE constexpr auto operator()(Coord const& coord) const { if constexpr (has_underscore::value) { return slice(coord, *this); } else { return swizzle_fn()(to_integral(offset_fn()) + layout_fn()(coord)); // (F o L)(c) } CUTE_GCC_UNREACHABLE; } // Map a 1D linear coordinate to a flat ND logical coordinate template ::value)> CUTE_HOST_DEVICE constexpr auto operator[](Int const& linear_idx) const { return get_flat_coord(linear_idx); } // Convenience function for multi-dimensional coordinates template CUTE_HOST_DEVICE constexpr auto operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { return operator()(make_coord(c0,c1,cs...)); } // // Compose // template CUTE_HOST_DEVICE constexpr auto compose(OtherLayout const& other) const { return composition(*this, other); } template CUTE_HOST_DEVICE constexpr auto compose(Layouts const&... layouts) const { return composition(*this, make_tile(layouts...)); } template CUTE_HOST_DEVICE constexpr auto with_shape(OtherShape const& shape) const { return composition(*this, make_layout(shape)); } template CUTE_HOST_DEVICE constexpr auto with_shape(Shapes const&... shapes) const { return composition(*this, make_layout(make_shape(shapes...))); } // // Tile // template CUTE_HOST_DEVICE constexpr auto tile(OtherLayout const& other) const { return tiled_divide(*this, other); } template CUTE_HOST_DEVICE constexpr auto tile(Layouts const&... layouts) const { return tiled_divide(*this, make_tile(layouts...)); } // // Utility // // // Index to Coordinate // // NOTE Only valid for compact layouts // Return the (hierarchical) ND logical coordinate corresponding to the linear index // @post this->crd2idx(@a result) == idx // @post congruent(@a result, shape()) template ::value)> CUTE_HOST_DEVICE constexpr auto get_hier_coord(IInt const& idx) const { return layout_fn().get_hier_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) } // Return the (flat) ND logical coordinate corresponding to the linear index // @post this->crd2idx(@a result) == idx // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 template ::value)> CUTE_HOST_DEVICE constexpr auto get_flat_coord(IInt const& idx) const { return layout_fn().get_flat_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) } // Return the generalized column-major 1D logical coordinate corresponding to the linear index // @post this->crd2idx(@a result) == idx // @post is_integral::value template ::value)> CUTE_HOST_DEVICE constexpr auto get_1d_coord(IInt const& idx) const { return layout_fn().get_1d_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) } }; template struct is_layout> : true_type {}; template struct is_composed_layout : false_type {}; template struct is_composed_layout> : true_type {}; // // Constructors // template CUTE_HOST_DEVICE constexpr auto make_layout(Swizzle const& sxor) { return composition(sxor, Layout,Int<1>>{}); } template CUTE_HOST_DEVICE constexpr auto make_layout(ComposedLayout const& a, Layout const& b) { return composition(a.swizzle_fn(), a.offset_fn(), make_layout(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto make_layout(Layout const& a, ComposedLayout const& b) { return composition(b.swizzle_fn(), b.offset_fn(), make_layout(a, b.layout_fn())); } namespace detail { template CUTE_HOST_DEVICE constexpr auto transfer_swizzle(Layout const& old_layout, Layout const& new_layout) { // Our goal is to determine a new swizzle for the strides in new_layout for consistent vectorizations // This is accomplished by identifying // S o L :=: S? o L* // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S // Then that active identifier is transformed through the layouts: // L*(L[(P o L)(c*)]) // which is a new swizzle identifier for S?, the new swizzle // Projections of the swizzle layout for composition, P auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] auto layout_only_zy = composition(swizzle_only_zy, old_layout); // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); // Get the Z bit and the Y bits -- keep only those that are active in Z *and* Y auto zzz_msk = typename Swizzle::zzz_msk{}; auto yyy_msk = typename Swizzle::yyy_msk{}; auto msk_sft = typename Swizzle::msk_sft{}; auto active_Z = swizzle_active_bits & shiftr(swizzle_active_bits, msk_sft) & zzz_msk; auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); // Use this new swizzle identifier to construct the new swizzle for new_layout // (this also makes sure it's a "valid" swizzle that Swizzle can represent) return composition(make_swizzle(), new_layout); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto make_fragment_like(ComposedLayout,Offset,Layout> const& layout) { return detail::transfer_swizzle(layout.layout_fn(), make_fragment_like(layout.layout_fn())); } // // Utilities // // Return the layout of a mode template CUTE_HOST_DEVICE constexpr decltype(auto) layout(ComposedLayout const& clayout) { return composition(clayout.swizzle_fn(), clayout.offset_fn(), layout(clayout.layout_fn())); } // Return the shape of a mode template CUTE_HOST_DEVICE constexpr decltype(auto) shape(ComposedLayout const& layout) { return shape(layout.layout_fn()); } // Doesn't make sense to directly ask for the strides of this "layout" template CUTE_HOST_DEVICE constexpr decltype(auto) stride(ComposedLayout const& layout) = delete; // Return the number of elements in a mode template CUTE_HOST_DEVICE constexpr decltype(auto) size(ComposedLayout const& layout) { return size(layout.layout_fn()); } // Return the number of modes template CUTE_HOST_DEVICE constexpr auto rank(ComposedLayout const& layout) { return rank(layout.layout_fn()); } // Return the depth of the layout template CUTE_HOST_DEVICE constexpr auto depth(ComposedLayout const& layout) { return depth(layout.layout_fn()); } // Return the codomain size of a mode template CUTE_HOST_DEVICE constexpr auto cosize(ComposedLayout const& layout) { return cosize(layout.layout_fn()); } // // Operations to manipulate Layouts like a tuple of pairs // template CUTE_HOST_DEVICE constexpr auto get(ComposedLayout const& a) { return composition(a.swizzle_fn(), a.offset_fn(), get(a.layout_fn())); } template CUTE_HOST_DEVICE constexpr auto take(ComposedLayout const& a) { return composition(a.swizzle_fn(), a.offset_fn(), take(a.layout_fn())); } template CUTE_HOST_DEVICE constexpr auto flatten(ComposedLayout const& a) { return composition(a.swizzle_fn(), a.offset_fn(), flatten(a.layout_fn())); } template CUTE_HOST_DEVICE constexpr auto append(ComposedLayout const& a, X const& x) { return composition(a.swizzle_fn(), a.offset_fn(), append(a.layout_fn(), x)); } template CUTE_HOST_DEVICE constexpr auto group(ComposedLayout const& a) { return composition(a.swizzle_fn(), a.offset_fn(), group(a.layout_fn())); } // // Slice a ComposedLayout // namespace detail { template CUTE_HOST_DEVICE constexpr auto make_swizzle_strides(true_type, IntZ const& Z, IntY const& Y, Offset const& offset, int_sequence) { // Below is an optimized/compressed version of: //return cute::make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); // with knowledge of Swizzle, I... ranges for each B bits, // and the layout won't slice along z-bits that are already set // y\z 0 1 // 0 Z DC // 1 -Z DC return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z << Int{}, -(Z << Int{}))...); } template CUTE_HOST_DEVICE constexpr auto make_swizzle_strides(false_type, IntZ const& Z, IntY const& Y, Offset const& offset, int_sequence) { // Below is an optimized/compressed version of: //return cute::make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); // with knowledge of Swizzle, I... ranges for each B bits, // and the layout won't slice along y-bits that are already set // y\z 0 1 // 0 Y+Z Y-Z // 1 DC DC return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) << Int{}, (Y-Z) << Int{})...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout> const& layout) { if constexpr (all_underscore::value) { // Skip the expensive/complicated attempt to decay to a normal layout and just reshape return cute::make_tuple(composition(layout.swizzle_fn(), layout.offset_fn(), slice(coord, layout.layout_fn())), Int<0>{}); } else { // Projections of the swizzle layout for composition auto sw = make_layout(make_shape(Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B)>{}, Int<1>{})); auto swizzle_anti_zy = make_layout(shape(sw), make_stride(stride<0>(sw), Int<0>{}, stride<2>(sw), Int<0>{}, size(sw))); auto swizzle_only_zy = make_layout(shape(sw), make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); // The portion of the layout that is not yet consumed auto sliced_layout = slice(coord, layout.layout_fn()); // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] // (this also tests that shape/stride of layout compose with swizzle) auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); // Determine if any active bits collide under the swizzle auto hit_ZandY = !(swizzle_active_bits & ~layout.swizzle_fn()(swizzle_active_bits)); // The portion of the layout that we are consuming now auto diced_layout = dice(coord, layout.layout_fn()); auto diced_coord = dice(coord, coord); auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); // New swizzle and offset auto swizzle = layout.swizzle_fn(); // offset_only_zy interacts with swizzle and gets accumulated with layout.offset_fn() // being careful about the static/dynamic contributions from diced_layout and diced_coord auto offset_only_zy = layout.offset_fn() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); // offset_anti_zy always gets passed through, no interaction with swizzle auto offset_anti_zy = diced_layout_anti_zy(diced_coord); // If Layout's codomain hits on Y AND Z, then it's not reducible // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal // If Layout's codomain hits on neither Y NOR Z, then it's static-normal // Test the sliced layout for hit_X & hit_Y for potential decay if constexpr (is_constant::value) { // Hits on Y AND Z, so it's not reducible return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); } else { // Misses on Y or Z, so it's static-normal or dynamic-normal // Lowest bit of the Z and Y masks auto Z = typename Swizzle::zzz_msk{} & -typename Swizzle::zzz_msk{}; auto Y = typename Swizzle::yyy_msk{} & -typename Swizzle::yyy_msk{}; auto stride_lo = detail::make_swizzle_strides(Z < Y, Z, Y, offset_only_zy, make_int_sequence{}); auto stride_hi = detail::make_swizzle_strides(Z > Y, Z, Y, offset_only_zy, make_int_sequence{}); // Construct a (dynamic) layout that we can perform the composition with auto swizzle_layout = make_layout(make_shape (Int<(1 << M)>{}, repeat(Int<2>{}), Int<(1 << (abs(S)-B))>{}, repeat(Int<2>{}), Int< 1>{}), make_stride(Int< 1>{}, stride_lo, Int<(1 << (M+B))>{}, stride_hi , Int<(1 << (M+B+abs(S)))>{})); // Decay to a normal layout with offset return cute::make_tuple(composition(swizzle_layout, sliced_layout), swizzle(to_integral(offset_only_zy)) + offset_anti_zy); } } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto slice(Coord const& coord, ComposedLayout const& layout) { return get<0>(slice_and_offset(coord, layout)); } // // composition // template CUTE_HOST_DEVICE constexpr auto composition(Swizzle const& sxor, Offset const& offset, Layout const& layout) { return ComposedLayout>{sxor, offset, layout}; } template CUTE_HOST_DEVICE constexpr auto composition(Swizzle const& sxor, Offset const& offset, ComposedLayout const& layout) { // Assume disjoint swizzles and offsets for commutivity return composition(composition(sxor,layout.swizzle_fn()), offset ^ layout.offset_fn(), layout.layout_fn()); } // Ignore identity case template CUTE_HOST_DEVICE constexpr auto composition(Swizzle<0,M,S> const&, Int<0> const&, Layout const& layout) { return layout; } template CUTE_HOST_DEVICE constexpr auto composition(Swizzle const& sxor, Layout const& layout) { return composition(sxor, Int<0>{}, layout); } template CUTE_HOST_DEVICE constexpr auto composition(ComposedLayout const& a, LayoutOrTile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), composition(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto composition(Layout const& a, Swizzle const& b) { // Get the Z bits and the Y bits auto active_Y = a(typename Swizzle::yyy_msk{}); auto active_Z = a(typename Swizzle::zzz_msk{}); // Works in simple cases... but could be greatly generalized return composition(make_swizzle(), a); } template CUTE_HOST_DEVICE constexpr auto composition(Layout const& a, ComposedLayout const& b) { CUTE_STATIC_ASSERT_V(b.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); return composition(composition(a, b.swizzle_fn()), b.layout_fn()); } template CUTE_HOST_DEVICE constexpr auto composition(ComposedLayout const& a, ComposedLayout const& b) { auto asb = composition(a.layout_fn(), b); return composition(composition(a.swizzle_fn(),asb.swizzle_fn()), asb.offset_fn(), asb.layout_fn()); } // // complement // template CUTE_HOST_DEVICE constexpr auto complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) { // Assume there is no swizzle component in the complement return complement(layout.layout_fn(), cosize_hi); } template CUTE_HOST_DEVICE constexpr auto complement(ComposedLayout const& layout) { return complement(layout, cosize(layout)); } // // inverse // template CUTE_HOST_DEVICE constexpr auto right_inverse(ComposedLayout const& layout) { CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); return composition(right_inverse(layout.layout_fn()), layout.swizzle_fn()); } template CUTE_HOST_DEVICE constexpr auto left_inverse(ComposedLayout const& layout) { CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); return composition(left_inverse(layout.layout_fn()), layout.swizzle_fn()); } // // Other operations // template CUTE_HOST_DEVICE constexpr auto max_common_vector(ComposedLayout,Offset,SLayout> const& a, Layout const& b) { // This assumes that Offset is in the YZ domain of the Swizzle... return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto max_common_vector(Layout const& a, ComposedLayout,Offset,SLayout> const& b) { return max_common_vector(b, a); } template CUTE_HOST_DEVICE constexpr auto max_common_vector(ComposedLayout,Offset0,SLayout0> const& a, ComposedLayout,Offset1,SLayout1> const& b) { auto result = coalesce(composition(a, right_inverse(b))); if constexpr (is_constant<1, decltype(stride<0>(result.layout_fn()))>::value) { return shape<0>(result); } else { return Int<1>{}; } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto zip(ComposedLayout const& a) { return composition(a.swizzle_fn(), a.offset_fn(), zip(a.layout_fn())); } // Partitions template CUTE_HOST_DEVICE constexpr auto logical_divide(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), logical_divide(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto tile_unzip(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), tile_unzip(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto tiled_divide(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), tiled_divide(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto zipped_divide(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), zipped_divide(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto logical_product(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), logical_product(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto tiled_product(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), tiled_product(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto blocked_product(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), blocked_product(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto raked_product(ComposedLayout const& a, Tile const& b) { return composition(a.swizzle_fn(), a.offset_fn(), raked_product(a.layout_fn(), b)); } template CUTE_HOST_DEVICE constexpr auto tile_to_shape(ComposedLayout const& layout, Shape const& trg_shape, ModeOrder const& ord_shape = {}) { return composition(layout.swizzle_fn(), layout.offset_fn(), tile_to_shape(layout.layout_fn(), trg_shape, ord_shape)); } template CUTE_HOST_DEVICE constexpr auto filter(ComposedLayout const& layout, Shape const& trg_profile) { return composition(layout.swizzle_fn(), layout.offset_fn(), filter(layout.layout_fn(), trg_profile)); } template CUTE_HOST_DEVICE constexpr auto coalesce(ComposedLayout const& layout) { return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn())); } template CUTE_HOST_DEVICE constexpr auto coalesce(ComposedLayout const& layout, Shape const& trg_profile) { return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn(), trg_profile)); } /////////////////////////////////////////////////////////////////////////////// // ComposedLayout as second argument is often more difficult... template CUTE_HOST_DEVICE constexpr auto logical_product(Layout const& block, ComposedLayout,Offset,LayoutT> const& tile) { CUTE_STATIC_ASSERT_V(tile.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); // The new layout -- if swizzle wasn't an issue, this is the result // our goal is to determine a new swizzle for these strides auto new_layout = logical_product(block, tile.layout_fn()); // This is accomplished by identifying // S o L :=: S? o L* // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S // Then that active identifier is transformed through the layouts: // L*(L[(P o L)(c*)]) // which is a new swizzle identifier for S?, the new swizzle // Projections of the swizzle layout for composition, P auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] auto layout_only_zy = composition(swizzle_only_zy, tile.layout_fn()); // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); // Get the Z bit and the Y bits auto active_Z = swizzle_active_bits & typename Swizzle::zzz_msk{}; auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) auto new_active_Z = new_layout(Int<0>{}, tile.layout_fn()[active_Z]); auto new_active_Y = new_layout(Int<0>{}, tile.layout_fn()[active_Y]); // Use this new swizzle identifier to construxt the new swizzle for new_layout // (this also makes sure it's a "valid" swizzle that Swizzle can represent) return composition(make_swizzle(), new_layout); } template CUTE_HOST_DEVICE constexpr auto tiled_product(Layout const& block, ComposedLayout const& tile) { /// Avoid swizzle slice auto result = logical_product(block, tile); return composition(result.swizzle_fn(), result.offset_fn(), result.layout_fn()(_, repeat>(_))); } template CUTE_HOST_DEVICE constexpr auto blocked_product(Layout const& block, ComposedLayout const& layout) { constexpr int R = cute::max(rank_v, rank_v); auto padded_block = append(block, Layout<_1,_0>{}); auto padded_layout = append(layout, Layout<_1,_0>{}); auto result = logical_product(padded_block, padded_layout); return composition(result.swizzle_fn(), result.offset_fn(), coalesce(zip(get<0>(result.layout_fn()), get<1>(result.layout_fn())), repeat(Int<1>{}))); } // // Upcast and Downcast // template CUTE_HOST_DEVICE constexpr auto upcast(ComposedLayout const& layout) { return composition(upcast(layout.swizzle_fn()), upcast(layout.offset_fn()), upcast(layout.layout_fn())); } template CUTE_HOST_DEVICE constexpr auto downcast(ComposedLayout const& layout) { return composition(downcast(layout.swizzle_fn()), downcast(layout.offset_fn()), downcast(layout.layout_fn())); } template CUTE_HOST_DEVICE constexpr auto recast(ComposedLayout const& layout) { if constexpr (sizeof(NewType) == sizeof(OldType)) { return layout; } else if constexpr (sizeof(NewType) > sizeof(OldType)) { static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); return upcast(layout); } else if constexpr (sizeof(NewType) < sizeof(OldType)) { static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); return downcast(layout); } CUTE_GCC_UNREACHABLE; } // // Display utilities // template CUTE_HOST_DEVICE void print(ComposedLayout const& layout) { print(layout.swizzle_fn()); print(" o "); print(layout.offset_fn()); print(" o "); print(layout.layout_fn()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) { return os << layout.swizzle_fn() << " o " << layout.offset_fn() << " o " << layout.layout_fn(); } #endif } // end namespace cute