cutlass/include/cutlass/detail/layout.hpp
2024-01-16 14:37:22 -05:00

264 lines
8.4 KiB
C++

/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/tensor.h"
#include "cute/layout.hpp"
#include "cute/arch/copy_sm90_tma.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::detail {
////////////////////////////////////////////////////////////////////////////////////////////////////
// For each cutlass::layout, provides its corresponding cute stride types, 64b by default
// Primary template: any type without a specialization below is passed through
// unchanged as the stride type.
template <class LayoutTag>
struct TagToStrideA {
  using type = LayoutTag;
};

// layout::RowMajor A, modes [M, K, L]: the K mode (index 1) holds the static unit stride.
template <>
struct TagToStrideA<layout::RowMajor> {
  using tag  = layout::RowMajor;
  using type = cute::Stride<int64_t, cute::Int<1>, int64_t>;
};

// layout::ColumnMajor A, modes [M, K, L]: the M mode (index 0) holds the static unit stride.
template <>
struct TagToStrideA<layout::ColumnMajor> {
  using tag  = layout::ColumnMajor;
  using type = cute::Stride<cute::Int<1>, int64_t, int64_t>;
};
// Primary template: any type without a specialization below is passed through
// unchanged as the stride type.
template <class LayoutTag>
struct TagToStrideB {
  using type = LayoutTag;
};

// layout::RowMajor B, modes [N, K, L]: the N mode (index 0) holds the static unit stride.
template <>
struct TagToStrideB<layout::RowMajor> {
  using tag  = layout::RowMajor;
  using type = cute::Stride<cute::Int<1>, int64_t, int64_t>;
};

// layout::ColumnMajor B, modes [N, K, L]: the K mode (index 1) holds the static unit stride.
template <>
struct TagToStrideB<layout::ColumnMajor> {
  using tag  = layout::ColumnMajor;
  using type = cute::Stride<int64_t, cute::Int<1>, int64_t>;
};
// Stride mapping for C, modes [M, N, L]. It derives from the A mapping, so it
// exposes exactly the members (`type`, and `tag` where defined) that
// TagToStrideA provides for the same layout tag.
template <class Tag>
struct TagToStrideC : public TagToStrideA<Tag> {};
// Convenience aliases that name the nested `type` of each tag->stride mapper.

// Shorthand for TagToStrideA<Tag>::type.
template <class Tag>
using TagToStrideA_t = typename TagToStrideA<Tag>::type;

// Shorthand for TagToStrideB<Tag>::type.
template <class Tag>
using TagToStrideB_t = typename TagToStrideB<Tag>::type;

// Shorthand for TagToStrideC<Tag>::type.
template <class Tag>
using TagToStrideC_t = typename TagToStrideC<Tag>::type;
////////////////////////////////////////////////////////////////////////////////////////////////////
// For 2.x compatibility APIs, provide stride->layout tag mappers
template<int ModeIndex, class Stride>
constexpr bool
is_major(Stride = {}) {
// Account for stride types with and without batch mode and batch modes with static zero stride
return cute::is_constant<1, decltype(cute::front(cute::get<ModeIndex>(Stride{})))>::value;
}
// Recovers the layout tag of A from its stride type: layout::ColumnMajor when
// mode 0 (M) is the stride-1 mode, layout::RowMajor otherwise.
// Note: the same deduction can be used for the layout tags of the A, C and D matrices.
template<class StrideA>
constexpr
auto
stride_to_layout_tag_A() {
  if constexpr (!is_major<0, StrideA>()) { // K major
    return layout::RowMajor{};
  }
  else { // M major
    return layout::ColumnMajor{};
  }
  CUTE_GCC_UNREACHABLE;
}
template<class StrideB>
constexpr
auto
stride_to_layout_tag_B() {
if constexpr (is_major<0, StrideB>()) { // N major
return layout::RowMajor{};
}
else { // K major
return layout::ColumnMajor{};
}
CUTE_GCC_UNREACHABLE;
}
template<class StrideC>
constexpr
auto
stride_to_layout_tag_C() {
if constexpr (is_major<0, StrideC>()) { // M major
return layout::ColumnMajor{};
}
else { // N major
return layout::RowMajor{};
}
CUTE_GCC_UNREACHABLE;
}
// Utilities to map Stride back on to their corresponding layout tags
// Type-level form of stride_to_layout_tag_A: `type` is the layout tag that
// function returns for `Stride`.
template <class Stride>
struct StrideToLayoutTagA {
  using type = decltype(detail::stride_to_layout_tag_A<Stride>());
};
// Type-level form of stride_to_layout_tag_B: `type` is the layout tag that
// function returns for `Stride`.
template <class Stride>
struct StrideToLayoutTagB {
  using type = decltype(detail::stride_to_layout_tag_B<Stride>());
};
// Type-level form of stride_to_layout_tag_C: `type` is the layout tag that
// function returns for `Stride`.
template <class Stride>
struct StrideToLayoutTagC {
  using type = decltype(detail::stride_to_layout_tag_C<Stride>());
};
// Convenience aliases that name the nested `type` of each stride->tag mapper.

// Shorthand for StrideToLayoutTagA<Stride>::type.
template <class Stride>
using StrideToLayoutTagA_t = typename StrideToLayoutTagA<Stride>::type;

// Shorthand for StrideToLayoutTagB<Stride>::type.
template <class Stride>
using StrideToLayoutTagB_t = typename StrideToLayoutTagB<Stride>::type;

// Shorthand for StrideToLayoutTagC<Stride>::type.
template <class Stride>
using StrideToLayoutTagC_t = typename StrideToLayoutTagC<Stride>::type;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Reports, at compile time, whether the copy engine named by GmemTiledCopy is TMA.
// Returns false for void. Otherwise returns true exactly when GmemTiledCopy is, or
// derives from, one of the SM90 TMA load/store operation types checked below.
template<class GmemTiledCopy>
constexpr bool is_tma_copy_engine() {
  if constexpr (!cute::is_void_v<GmemTiledCopy>) {
    constexpr bool kIsTmaLoad =
         cute::is_base_of_v<cute::SM90_TMA_LOAD,                  GmemTiledCopy>
      || cute::is_base_of_v<cute::SM90_TMA_LOAD_MULTICAST,        GmemTiledCopy>
      || cute::is_base_of_v<cute::SM90_TMA_LOAD_IM2COL,           GmemTiledCopy>
      || cute::is_base_of_v<cute::SM90_TMA_LOAD_IM2COL_MULTICAST, GmemTiledCopy>;
    constexpr bool kIsTmaStore =
         cute::is_base_of_v<cute::SM90_TMA_STORE,                 GmemTiledCopy>
      || cute::is_base_of_v<cute::SM90_TMA_STORE_IM2COL,          GmemTiledCopy>;
    if constexpr (kIsTmaLoad || kIsTmaStore) {
      return true;
    }
  }
  // Reached for void and for every non-TMA type.
  return false;
}
// Inspects a TiledCopy and returns its alignment expressed as a number of elements.
//   - GmemTiledCopy is void            -> 1
//   - Element is void (e.g. ElementC = void kernels) -> 0
//   - TMA copy engine                  -> 128 / sizeof_bits<Element>::value
//   - any other tiled copy             -> GmemTiledCopy::NumValSrc
// The checks are applied in that order, so a void GmemTiledCopy yields 1 even
// when Element is also void.
template <class GmemTiledCopy, class Element>
constexpr int
get_alignment_count_from_gmem_tiled_copy() {
  if constexpr (cute::is_void_v<GmemTiledCopy>) {
    return 1;
  }
  else if constexpr (cute::is_void_v<Element>) {
    // Account for ElementC = void kernels
    return 0;
  }
  else if constexpr (is_tma_copy_engine<GmemTiledCopy>()) {
    // For TMA tiled copies, the alignment has to be 128 bits
    return 128 / sizeof_bits<Element>::value;
  }
  else {
    // For non-TMA tiled copies, the alignment count is read directly from NumValSrc
    return GmemTiledCopy::NumValSrc;
  }
}
// Return the shape that is associated with the stride-1 mode, or 1 if not found.
// Both shape and stride are flattened and extended with a trailing static 1. The
// position of the first static-1 stride is located at compile time and the shape
// element at that same position is returned. The appended entry guarantees the
// search always finds a match.
template<typename Shape, typename Stride>
CUTLASS_HOST_DEVICE constexpr
auto
get_contiguous_shape(Shape const & shape, Stride const & stride) {
  using namespace cute;
  auto flat_stride = append(flatten(stride), _1{});
  auto flat_shape  = append(flatten(shape),  _1{});
  // The predicate returns a compile-time constant, so the found index is encoded in its type.
  auto unit_stride_pos = find_if(flat_stride, [](auto stride_elem) {
    return is_constant<1, decltype(stride_elem)>{};
  });
  return get<decltype(unit_stride_pos)::value>(flat_shape);
}
// Check if a tensor shape satisfies a given major alignment.
// If mode 0 is the stride-1 mode (per is_major<0>), the contiguous extent of
// mode 0 is tested; otherwise the contiguous extent of mode 1 is tested. Returns
// true when that extent is a multiple of `Alignment`.
template<int Alignment, class Shape, class Stride>
CUTLASS_HOST_DEVICE constexpr
bool
check_alignment(Shape const & shape, Stride const & stride) {
  // Which mode is contiguous depends only on the stride's type.
  constexpr bool kMode0IsMajor = is_major<0, Stride>();
  return kMode0IsMajor
    ? (get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0)
    : (get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0);
}
// Compute the alignment implied by a swizzle functor cute::Swizzle<B, M, S>:
// returns 2^(B + M + |S|). B and M must be non-negative; S may have either sign.
// NOTE(review): the unit of the returned alignment is not stated in this file —
// confirm against callers.
template<int B, int M, int S>
CUTLASS_HOST_DEVICE constexpr
size_t
alignment_for_swizzle(cute::Swizzle<B, M, S>) {
  static_assert(B >= 0 and M >= 0);
  size_t const log2_alignment = size_t(B + M + cute::abs(S));
  return size_t(1) << log2_alignment;
}
// Overload for layout arguments: extracts the swizzle portion of `layout` via
// cute::detail::get_swizzle_portion and returns the alignment computed for it.
// The extracted value is expected to be a cute::Swizzle<B, M, S>, which selects
// the overload taking that type.
template<class Layout>
CUTLASS_HOST_DEVICE constexpr
size_t
alignment_for_swizzle(Layout layout) {
  auto swizzle = cute::detail::get_swizzle_portion(layout);
  return alignment_for_swizzle(swizzle);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::detail