610 lines
17 KiB
C++
610 lines
17 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include <cute/config.hpp>
|
|
|
|
#include <cute/layout.hpp>
|
|
|
|
/* This implements a ComposedLayout of the form
|
|
* LayoutA o Offset o LayoutB
|
|
* and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB.
|
|
* For example, then the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB).
|
|
*
|
|
* This ComposedLayout provides similar functionality to Layout including tiling, partitioning,
|
|
* coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout.
|
|
* For example, this layout provides shape() and size() functions, but does not provide stride() functions.
|
|
* Mostly, the similar functionality is accomplished by applying each operation to LayoutB only
|
|
* as LayoutB defines the domain.
|
|
*/
|
|
|
|
namespace cute
|
|
{
|
|
|
|
// A Layout of non-trivially composable functions: F o I o L
|
|
template <class LayoutA, class Offset, class LayoutB>
|
|
struct ComposedLayout : private cute::tuple<LayoutA, Offset, LayoutB> // EBO for static layouts
|
|
{
|
|
CUTE_HOST_DEVICE constexpr
|
|
ComposedLayout(LayoutA const& layoutA = {},
|
|
Offset const& offset = {},
|
|
LayoutB const& layoutB = {})
|
|
: cute::tuple<LayoutA, Offset, LayoutB>(layoutA, offset, layoutB)
|
|
{}
|
|
|
|
//
|
|
// Accessors
|
|
//
|
|
|
|
static constexpr int rank = LayoutB::rank;
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
layout_a() const {
|
|
return get<0>(static_cast<cute::tuple<LayoutA, Offset, LayoutB> const&>(*this));
|
|
}
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
offset() const {
|
|
return get<1>(static_cast<cute::tuple<LayoutA, Offset, LayoutB> const&>(*this));
|
|
}
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
layout_b() const {
|
|
return get<2>(static_cast<cute::tuple<LayoutA, Offset, LayoutB> const&>(*this));
|
|
}
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
layout() const {
|
|
return *this;
|
|
}
|
|
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
shape() const {
|
|
return layout_b().shape();
|
|
}
|
|
|
|
// Doesn't really make sense to ask for the strides of this "layout"
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
stride() const = delete;
|
|
|
|
//
|
|
// Mappings
|
|
//
|
|
|
|
// Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
|
|
// OR
|
|
// Slice the layout and return the sublayout (Coord has an Underscore slice op)
|
|
template <class Coord>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator()(Coord const& coord) const {
|
|
if constexpr (has_underscore<Coord>::value) {
|
|
return slice(coord, *this);
|
|
} else {
|
|
return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c)
|
|
}
|
|
|
|
CUTE_GCC_UNREACHABLE;
|
|
}
|
|
|
|
// Convenience function for multi-dimensional coordinates
|
|
template <class Coord0, class Coord1, class... Coords>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
|
|
return operator()(make_coord(c0,c1,cs...));
|
|
}
|
|
|
|
//
|
|
// Compose
|
|
//
|
|
|
|
template <class OtherLayout>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
compose(OtherLayout const& other) const {
|
|
return composition(*this, other);
|
|
}
|
|
|
|
template <class... Layouts>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
compose(Layouts const&... layouts) const {
|
|
return composition(*this, make_tile(layouts...));
|
|
}
|
|
|
|
template <class OtherShape>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
with_shape(OtherShape const& shape) const {
|
|
return composition(*this, make_layout(shape));
|
|
}
|
|
|
|
template <class... Shapes>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
with_shape(Shapes const&... shapes) const {
|
|
return composition(*this, make_layout(make_shape(shapes...)));
|
|
}
|
|
|
|
//
|
|
// Tile
|
|
//
|
|
|
|
template <class OtherLayout>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
tile(OtherLayout const& other) const {
|
|
return tiled_divide(*this, other);
|
|
}
|
|
|
|
template <class... Layouts>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
tile(Layouts const&... layouts) const {
|
|
return tiled_divide(*this, make_tile(layouts...));
|
|
}
|
|
};
|
|
|
|
template <class A, class O, class B>
|
|
struct is_layout<ComposedLayout<A,O,B>> : true_type {};
|
|
|
|
template <class T>
|
|
struct is_composed_layout : false_type {};
|
|
template <class A, class O, class B>
|
|
struct is_composed_layout<ComposedLayout<A,O,B>> : true_type {};
|
|
|
|
//
|
|
// Constructors
|
|
//
|
|
|
|
template <class LayoutA, class Offset, class LayoutB>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
make_composed_layout(LayoutA const& layoutA,
|
|
Offset const& offset,
|
|
LayoutB const& layoutB)
|
|
{
|
|
return ComposedLayout<LayoutA, Offset, LayoutB>{layoutA, offset, layoutB};
|
|
}
|
|
|
|
//
|
|
// Utilities
|
|
//
|
|
|
|
// Return the layout of a mode
|
|
template <int... Is, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
layout(ComposedLayout<A,O,B> const& clayout)
|
|
{
|
|
return composition(clayout.layout_a(), clayout.offset(), layout<Is...>(clayout.layout_b()));
|
|
}
|
|
|
|
// Return the shape of a mode
|
|
template <int... Is, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
shape(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return shape<Is...>(layout.layout_b());
|
|
}
|
|
|
|
// Doesn't make sense to directly ask for the strides of this "layout"
|
|
template <int... Is, class Fn, class O, class Layout>
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
stride(ComposedLayout<Fn,O,Layout> const& layout) = delete;
|
|
|
|
// Return the number of elements in a mode
|
|
template <int... Is, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
decltype(auto)
|
|
size(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return size<Is...>(layout.layout_b());
|
|
}
|
|
|
|
// Return the number of modes
|
|
template <int... Is, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
rank(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return rank<Is...>(layout.layout_b());
|
|
}
|
|
|
|
// Return the depth of the layout
|
|
template <int... Is, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
depth(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return depth<Is...>(layout.layout_b());
|
|
}
|
|
|
|
// Return the codomain size of a mode
|
|
template <int... Is, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
cosize(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return cosize<Is...>(layout.layout_b());
|
|
}
|
|
|
|
//
|
|
// Operations to manipulate Layouts like a tuple of pairs
|
|
//
|
|
|
|
template <size_t I, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
get(ComposedLayout<A,O,B> const& a)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), get<I>(a.layout_b()));
|
|
}
|
|
|
|
template <int Begin, int End, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
take(ComposedLayout<A,O,B> const& a)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), take<Begin,End>(a.layout_b()));
|
|
}
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
flatten(ComposedLayout<A,O,B> const& a)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), flatten(a.layout_b()));
|
|
}
|
|
|
|
template <int N, class A, class O, class B, class X>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
append(ComposedLayout<A,O,B> const& a, X const& x)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), append<N>(a.layout_b(), x));
|
|
}
|
|
|
|
template <int Begin, int End, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
group(ComposedLayout<A,O,B> const& a)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), group<Begin,End>(a.layout_b()));
|
|
}
|
|
|
|
//
|
|
// Slice a ComposedLayout
|
|
//
|
|
|
|
template <class Coord, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
slice_and_offset(Coord const& coord, ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
auto [slice, offset] = slice_and_offset(coord, layout.layout_b());
|
|
return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{});
|
|
}
|
|
|
|
template <class Coord, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
slice(Coord const& coord, ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return get<0>(slice_and_offset(coord, layout));
|
|
}
|
|
|
|
// Compute a pointer offset and (potentially modified) layout from a coordinate
|
|
// For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated
|
|
template <class Coord, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
domain_offset(Coord const& coord, ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{});
|
|
}
|
|
|
|
//
|
|
// composition
|
|
//
|
|
|
|
template <class LayoutA,
|
|
class Offset,
|
|
class LayoutB>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
composition(LayoutA const& layoutA,
|
|
Offset const& offset,
|
|
LayoutB const& layoutB)
|
|
{
|
|
return ComposedLayout<LayoutA, Offset, LayoutB>{layoutA, offset, layoutB};
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class LayoutOrTile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
composition(ComposedLayout<A,O,B> const& a,
|
|
LayoutOrTile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b));
|
|
}
|
|
|
|
template <class ShapeA, class StrideA,
|
|
class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
composition(Layout<ShapeA,StrideA> const& a,
|
|
ComposedLayout<A,O,B> const& b)
|
|
{
|
|
CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0.");
|
|
|
|
return composition(composition(a, b.layout_a()), b.layout_b());
|
|
}
|
|
|
|
//
|
|
// complement
|
|
//
|
|
|
|
template <class A, class O, class B, class CoSizeHi>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
complement(ComposedLayout<A,O,B> const& layout, CoSizeHi const& cosize_hi)
|
|
{
|
|
return complement(layout.layout_b(), cosize_hi);
|
|
}
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
complement(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return complement(layout, cosize(layout));
|
|
}
|
|
|
|
//
|
|
// inverse
|
|
//
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
right_inverse(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a()));
|
|
}
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
left_inverse(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a()));
|
|
}
|
|
|
|
//
|
|
// Other operations
|
|
//
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
zip(ComposedLayout<A,O,B> const& a)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), zip(a.layout_b()));
|
|
}
|
|
|
|
// Partitions
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
logical_divide(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
tile_unzip(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
tiled_divide(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
zipped_divide(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
logical_product(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
tiled_product(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
blocked_product(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Tile>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
raked_product(ComposedLayout<A,O,B> const& a,
|
|
Tile const& b)
|
|
{
|
|
return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Shape, class ModeOrder = GenColMajor>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
tile_to_shape(ComposedLayout<A,O,B> const& layout,
|
|
Shape const& trg_shape,
|
|
ModeOrder const& ord_shape = {})
|
|
{
|
|
return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape));
|
|
}
|
|
|
|
template <class A, class O, class B,
|
|
class Shape>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
filter(ComposedLayout<A,O,B> const& layout, Shape const& trg_profile)
|
|
{
|
|
return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile));
|
|
}
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
coalesce(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b()));
|
|
}
|
|
|
|
template <class A, class O, class B, class Shape>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
coalesce(ComposedLayout<A,O,B> const& layout, Shape const& trg_profile)
|
|
{
|
|
return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile));
|
|
}
|
|
|
|
//
|
|
// Upcast and Downcast
|
|
//
|
|
|
|
template <int N, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
upcast(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return composition(upcast<N>(layout.layout_a()), upcast<N>(layout.offset()), upcast<N>(layout.layout_b()));
|
|
}
|
|
|
|
template <int N, class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
downcast(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return composition(downcast<N>(layout.layout_a()), downcast<N>(layout.offset()), downcast<N>(layout.layout_b()));
|
|
}
|
|
|
|
template <class OldType, class NewType,
|
|
class A, class O, class B>
|
|
CUTE_HOST_DEVICE constexpr
|
|
auto
|
|
recast(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
if constexpr (sizeof(NewType) == sizeof(OldType)) {
|
|
return layout;
|
|
} else if constexpr (sizeof(NewType) > sizeof(OldType)) {
|
|
static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType");
|
|
return upcast<sizeof(NewType)/sizeof(OldType)>(layout);
|
|
} else if constexpr (sizeof(NewType) < sizeof(OldType)) {
|
|
static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType");
|
|
return downcast<sizeof(OldType)/sizeof(NewType)>(layout);
|
|
}
|
|
|
|
CUTE_GCC_UNREACHABLE;
|
|
}
|
|
|
|
//
|
|
// Display utilities
|
|
//
|
|
|
|
template <class A, class O, class B>
|
|
CUTE_HOST_DEVICE void print(ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b());
|
|
}
|
|
|
|
#if !defined(__CUDACC_RTC__)
|
|
template <class A, class O, class B>
|
|
CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout<A,O,B> const& layout)
|
|
{
|
|
return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b();
|
|
}
|
|
#endif
|
|
|
|
} // end namespace cute
|