cutlass/include/cute/tensor_zip.hpp
Yujia Zhai cc3c29a81a
CUTLASS 3.6.0 (#1850)
* v3.6

* update changelog

* update readme

* fix typo

* fixing typos

* hopper gemm with weight prefetch

---------

Co-authored-by: yuzhai <yuzhai@nvidia.com>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2024-10-09 15:33:27 -04:00

244 lines
8.4 KiB
C++

/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp> // CUTE_HOST_DEVICE
#include <cute/tensor_impl.hpp> // cute::Tensor
#include <cute/container/tuple.hpp> // cute::tuple
namespace cute
{
// A tuple of Iterators that can be offset asymmetrically
// Note that this only accepts op+(tuple<Index...>) and op[tuple<Index...>]
// where each iterator will be offset by its respective index only.
// READ-ONLY for now until cute::tuple can be constructed with references.
template <class... Iters>
struct ZipIterator
{
using value_type = cute::tuple<iter_value_t<Iters>...>;
using element_type = cute::tuple<iter_element_t<Iters>...>;
// NOTE: cute::tuple does not support constructions with references at the moment.
// Consider fixes and/or an implementation of std::forward_as_tuple.
// For now, use a cute::tuple of value_types instead, which makes this Iterator READ-ONLY.
//using reference = cute::tuple<iter_reference_t<Iters>...>;
using reference = value_type;
ZipIterator() = delete;
CUTE_HOST_DEVICE constexpr
ZipIterator(Iters... iters)
: iters_(iters...)
{}
CUTE_HOST_DEVICE constexpr
ZipIterator(cute::tuple<Iters...> const& iters)
: iters_(iters)
{}
CUTE_HOST_DEVICE constexpr
reference operator*() const {
return cute::apply(iters_, [](auto&&... args) { return reference(*args...); });
}
template <class... Index>
CUTE_HOST_DEVICE constexpr
ZipIterator operator+(cute::tuple<Index...> const& idxs) const {
static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators.");
return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; });
}
template <class... Index>
CUTE_HOST_DEVICE constexpr
reference operator[](cute::tuple<Index...> const& idxs) const {
return *(*this + idxs);
}
cute::tuple<Iters...> iters_;
};
//------------------------------------------------------------------------------
// type traits
template <class... Iters>
struct is_rmem<ZipIterator<Iters...>> : conjunction<is_rmem<Iters>...> {};
template <class... Iters>
struct is_smem<ZipIterator<Iters...>> : conjunction<is_smem<Iters>...> {};
template <class... Iters>
struct is_gmem<ZipIterator<Iters...>> : conjunction<is_gmem<Iters>...> {};
// A tuple of Layouts that operates on each Layout symmetrically
// The Layouts need to have compatible shapes and ranks.
// The ZipLayout presents the intersection of the domain of its component Layouts.
// E.g. all Layouts accept 1D coords and ZipLayout does as well.
// The ZipLayout returns the union of the codomain of its component Layouts.
// E.g. all Layouts return an integer so ZipLayout returns a tuple of integers.
template <class... Layouts>
struct ZipLayout
{
static constexpr int rank = (int(0) | ... | Layouts::rank);
static_assert((is_layout<Layouts>::value && ...), "All template parameters must be layouts");
static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank");
CUTE_HOST_DEVICE constexpr
ZipLayout(Layouts const&... layouts)
: layouts_(layouts...)
{}
CUTE_HOST_DEVICE constexpr
ZipLayout(cute::tuple<Layouts...> const& layouts)
: layouts_(layouts)
{}
template <class Coord>
CUTE_HOST_DEVICE constexpr
auto
operator()(Coord const& coord) const {
if constexpr (has_underscore<Coord>::value) {
return ZipLayout(cute::transform(layouts_, [&] (auto layout) { return layout(coord); }));
} else {
return cute::transform(layouts_, [&] (auto layout) { return layout(coord); });
}
CUTE_GCC_UNREACHABLE;
}
// op() convenience function for multi-dimensional coordinates
template <class Coord0, class Coord1, class... Coords>
CUTE_HOST_DEVICE constexpr
decltype(auto)
operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const {
return operator()(make_coord(c0,c1,cs...));
}
cute::tuple<Layouts...> layouts_;
};
template <class... Layouts>
struct is_layout<ZipLayout<Layouts...>> : true_type {};
//
// make_zip_tensor and unzip_tensor
//
template <class... Engines, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
make_zip_tensor(Tensor<Engines,Layouts> const&... tensors)
{
return make_tensor(ZipIterator(tensors.data()...),
ZipLayout(tensors.layout()...));
}
template <class Engine, class Layout>
CUTE_HOST_DEVICE constexpr
auto
unzip_tensor(Tensor<Engine,Layout> const& tensor)
{
return cute::transform(tensor.data().iters_, tensor.layout().layouts_,
[](auto iter, auto layout) { return make_tensor(iter, layout); });
}
//
// Utilities
//
template <int... Is, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
rank(ZipLayout<Layouts...> const& layouts)
{
return rank<Is...>(get<0>(layouts.layouts_));
}
template <int... Is, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
size(ZipLayout<Layouts...> const& layouts)
{
return size<Is...>(get<0>(layouts.layouts_));
}
//
// Manipulation
//
// Extend each component layout to rank-N by appending Layout @a x.
template <int N, class... Layouts, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
append(ZipLayout<Layouts...> const& layouts,
Layout<ShapeX,StrideX> const& x = {})
{
return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return append<N>(t, x); }));
}
// Extend each component layout to rank-N by prepending Layout @a x.
template <int N, class... Layouts, class ShapeX = _1, class StrideX = _0>
CUTE_HOST_DEVICE constexpr
auto
prepend(ZipLayout<Layouts...> const& layouts,
Layout<ShapeX,StrideX> const& x = {})
{
return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return prepend<N>(t, x); }));
}
template <class... Layouts, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
logical_divide(ZipLayout<Layouts...> const& layouts,
Tiler const& tiler)
{
return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return logical_divide(t, tiler); }));
}
template <class... Layouts, class Tiler>
CUTE_HOST_DEVICE constexpr
auto
zipped_divide(ZipLayout<Layouts...> const& layouts,
Tiler const& tiler)
{
return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return zipped_divide(t, tiler); }));
}
// Return <SlicedZipLayout, ZipOffsets> by calling slice_and_offset and all component layouts.
template <class Coord, class... Layouts>
CUTE_HOST_DEVICE constexpr
auto
slice_and_offset(Coord const& c, ZipLayout<Layouts...> const& layouts)
{
auto result = cute::zip(cute::transform(layouts.layouts_, [&c](auto const& layout) { return slice_and_offset(c, layout); }));
return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result));
}
} // end namespace cute