/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/arch/copy.hpp>
#include <cute/atom/copy_traits.hpp>

#include <cute/layout.hpp>
#include <cute/tensor.hpp>

namespace cute
{

// Generic copy_unpack for any Copy_Traits
template <class Operation, class... CT_Args,
          class TS, class SLayout,
          class TD, class DLayout>
CUTE_HOST_DEVICE constexpr
void
copy_unpack(Copy_Traits<Operation, CT_Args...> const&,
            Tensor<TS,SLayout> const& src,
            Tensor<TD,DLayout>      & dst)
{
  // Specializations can generalize on these checks
  //static_assert(is_smem<TS>::value, "Expected smem for this Copy_Traits<Operation>");
  //static_assert(is_rmem<TD>::value, "Expected rmem for this Copy_Traits<Operation>");

  using RegistersSrc = typename Operation::SRegisters;
  using RegistersDst = typename Operation::DRegisters;
  using RegTypeSrc   = typename std::remove_extent<RegistersSrc>::type;
  using RegTypeDst   = typename std::remove_extent<RegistersDst>::type;
  constexpr int RegNumSrc = std::extent<RegistersSrc>::value;
  constexpr int RegNumDst = std::extent<RegistersDst>::value;

  Tensor rS = recast<RegTypeSrc>(src);
  Tensor rD = recast<RegTypeDst>(dst);

  CUTE_STATIC_ASSERT_V(size(rS) == Int<RegNumSrc>{},
    "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy.");
  CUTE_STATIC_ASSERT_V(size(rD) == Int<RegNumDst>{},
    "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy.");

  detail::explode(Operation::copy,
                  rS, make_int_sequence<RegNumSrc>{},
                  rD, make_int_sequence<RegNumDst>{});
}
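// A sketch of the expansion above, assuming the SM75_U32x4_LDSM_N op from cute/arch/copy_sm75.hpp
// (any op with SRegisters/DRegisters follows the same pattern):
//   SRegisters is uint128_t[1] and DRegisters is uint32_t[4], so recast<uint128_t>(src) must have
//   static size 1 and recast<uint32_t>(dst) static size 4 for the asserts above to pass.
//   detail::explode then expands to the single instruction call
//     SM75_U32x4_LDSM_N::copy(rS(0), rD(0), rD(1), rD(2), rD(3));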
template <class... Args>
struct Copy_Atom;

template <class CopyOperation, class T>
struct Copy_Atom<CopyOperation, T> : Copy_Atom<Copy_Traits<CopyOperation>, T>
{};

template <class... Args, class T>
struct Copy_Atom<Copy_Traits<Args...>, T>
  : Copy_Traits<Args...>
{
  using Traits = Copy_Traits<Args...>;

  // Bit and Thr layouts from the Copy_Traits
  using ThrID        = typename Traits::ThrID;
  using BitLayoutSrc = typename Traits::SrcLayout;
  using BitLayoutDst = typename Traits::DstLayout;
  using BitLayoutRef = typename Traits::RefLayout;

  using ValType = T;

  using ValLayoutSrc = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutSrc{}));
  using ValLayoutDst = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutDst{}));
  using ValLayoutRef = decltype(upcast<sizeof_bits<ValType>::value>(BitLayoutRef{}));

  CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType.");
  CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType.");
  CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType.");

  static constexpr int NumValSrc = size<1>(ValLayoutSrc{});
  static constexpr int NumValDst = size<1>(ValLayoutDst{});

  // Additional Trait parameters/transformations
  template <class... TraitsArgs>
  CUTE_HOST_DEVICE
  auto
  with(TraitsArgs&&... args) const {
    auto traits = Traits::with(std::forward<TraitsArgs>(args)...);
    return Copy_Atom<decltype(traits), T>{traits};
  }

  // Print thread and data layouts for debugging
  CUTE_HOST_DEVICE static
  void
  print_all()
  {
    print("ThrID:        "); print(ThrID{});        print("\n");
    print("BitLayoutSrc: "); print(BitLayoutSrc{}); print("\n");
    print("BitLayoutDst: "); print(BitLayoutDst{}); print("\n");
    print("BitLayoutRef: "); print(BitLayoutRef{}); print("\n");
    print("ValLayoutSrc: "); print(ValLayoutSrc{}); print("\n");
    print("ValLayoutDst: "); print(ValLayoutDst{}); print("\n");
    print("ValLayoutRef: "); print(ValLayoutRef{}); print("\n");
    print("ValueType:    %db", sizeof_bits<ValType>::value); print("\n");
  }

  //
  // Tensor call interfaces
  //

  // Cast, check, and call
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE
  void
  call(Tensor<TS,SLayout> const& src,
       Tensor<TD,DLayout>      & dst) const
  {
    static_assert(SLayout::rank == 1, "Expected rank-1 src tensor");
    static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor");

    if constexpr (is_constant<NumValSrc, decltype(size(src))>::value ||
                  is_constant<NumValDst, decltype(size(dst))>::value) {
      // Dispatch to unpack for instruction
      return copy_unpack(*this, src, dst);
    } else {
      // Recurse if needed by peeling the tensor mode
      return copy(*this, tensor<0>(src), tensor<0>(dst));
    }
  }

  // Accept mutable temporaries
  template <class TS, class SLayout,
            class TD, class DLayout>
  CUTE_HOST_DEVICE
  void
  call(Tensor<TS,SLayout> const& src,
       Tensor<TD,DLayout>     && dst) const
  {
    return call(src, dst);
  }
};
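// Example construction (a sketch, assuming half_t data and the ops in cute/arch/copy.hpp and
// cute/arch/copy_sm75.hpp):
//
//   Copy_Atom<SM75_U32x4_LDSM_N, half_t>        ldsm_atom{};  // ldmatrix.x4: 32 threads, 8 half_t each
//   Copy_Atom<UniversalCopy<uint128_t>, half_t> vec_atom{};   // generic 128-bit vectorized load/store
//
// Atoms are rarely called directly; they are tiled into a TiledCopy (below) and dispatched
// through cute::copy(tiled_copy, src, dst).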
//
// A tiling of copy atoms
//

template <class Copy_Atom,
          class LayoutCopy_TV,  // (tid,vid) -> coord   [Need not be 2D...]
          class ShapeTile_MN>   // coord space
struct TiledCopy : Copy_Atom
{
  // Layout information from the CopyAtom
  using AtomThrID     = typename Copy_Atom::ThrID;         // thrid -> thr_idx
  using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc;  // (thr,val) -> offset
  using AtomLayoutDst = typename Copy_Atom::ValLayoutDst;  // (thr,val) -> offset
  using AtomLayoutRef = typename Copy_Atom::ValLayoutRef;  // (thr,val) -> offset

  using AtomNumThr = decltype(size<0>(AtomLayoutRef{}));
  using AtomNumVal = decltype(size<1>(AtomLayoutRef{}));

  // Layout information for the TiledCopy
  using Tiler_MN       = ShapeTile_MN;
  using TiledShape_MN  = decltype(shape(ShapeTile_MN{}));
  using TiledLayout_TV = LayoutCopy_TV;
  using TiledNumThr    = decltype(size<0>(TiledLayout_TV{}));
  using TiledNumVal    = decltype(size<1>(TiledLayout_TV{}));

  CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom");
  CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom");

  // Tile a tensor or a layout from shape
  //   (M,N,...)
  // to shape
  //   ((ThrV,ThrX),FrgV,(RestM,RestN,...))
  // where
  //   ThrV:  The threads local to a COPY_ATOM Src.
  //   ThrX:  The threads tiled across COPY_ATOMs Src.
  //   FrgV:  The values local to a COPY_ATOM Src.
  //   RestM: The values tiled in M.
  //   RestN: The values tiled in N.
  template <class STensor>
  CUTE_HOST_DEVICE constexpr static
  auto
  tidfrg_S(STensor&& stensor)
  {
    return thrfrg(stensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}));
  }
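  // Size-level sketch, assuming the TiledCopy produced by
  //   make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, half_t>{},
  //                   Layout<Shape<_16,_8>>{}, Layout<Shape<_1,_8>>{})
  // defined later in this header: Tiler_MN is (_16,_64) and TiledNumThr is 128, so applying
  // tidfrg_S to a (32,128) tensor produces a thread mode of size 128, an 8-value fragment
  // mode, and a (2,2) rest mode of tile repeats.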
  // Tile a tensor or a layout from shape
  //   (M,N,...)
  // to shape
  //   ((ThrV,ThrX),FrgV,(RestM,RestN,...))
  // where
  //   ThrV:  The threads local to a COPY_ATOM Dst.
  //   ThrX:  The threads tiled across COPY_ATOMs Dst.
  //   FrgV:  The values local to a COPY_ATOM Dst.
  //   RestM: The values tiled in M.
  //   RestN: The values tiled in N.
  template <class DTensor>
  CUTE_HOST_DEVICE constexpr static
  auto
  tidfrg_D(DTensor&& dtensor)
  {
    return thrfrg(dtensor, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}));
  }

  template <class Tensor, class Ref2TrgLayout>
  CUTE_HOST_DEVICE constexpr static
  auto
  thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg)
  {
    constexpr int R = remove_cvref_t<Tensor>::rank;
    static_assert(R >= rank_v<TiledShape_MN>, "Rank of tensor to be partitioned too small.");
    // Generalize the dimension checks for arbitrary rank
    //CUTE_STATIC_ASSERT_V(size<0>(stensor) % size<0>(TiledShape_MNK{}) == Int<0>{});
    //CUTE_STATIC_ASSERT_V(size<1>(stensor) % size<1>(TiledShape_MNK{}) == Int<0>{});

    // Take the thrs/vals that the atom is interested in
    // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
    auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{}));
    // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n)

    // Transform to the trg layout
    auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _);
    // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n)

    // Transform the thrs mode from thrid to thr_idx
    // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID
    auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{});
    // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n)

    /// ==================

    // Tile the tensor for TiledLayout
    auto t_tensor = zipped_divide(tensor, Tiler_MN{});
    // ((TileM,TileN,...),(RestM,RestN,...))

    // Transform the tile mode
    auto tv_tensor = t_tensor.compose(thrval2mn, _);
    // ((thrid,val),(RM,RN,...))

    // Unfold and return
    return tv_tensor(make_coord(_,_), _);
  }
  // retile_S and retile_D assume they are working with the reference layout -- they are the same
  template <class Tensor>
  CUTE_HOST_DEVICE constexpr static
  auto
  retile(Tensor&& tensor)
  {
    constexpr int R = remove_cvref_t<Tensor>::rank;
    // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation

    // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV.
    // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV
    //   and that shape is what we gather from the other modes of tensor

    auto V = size<0>(tensor);

    auto frg_layout_mn = upcast<TiledNumThr{} * V>(right_inverse(TiledLayout_TV{}).with_shape(TiledShape_MN{}));
    // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV

    auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{}));
    // (atom_vals,rest_vals) -> (v,m,n)

    /// =======

    // Tile the tensor for TileFrg
    auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V));
    // ((TileV,TileM,TileN,...),(1,RestM,RestN,...))

    // Transform the tile mode
    auto v_tensor = t_tensor.compose(frg_layout_v, _);
    // ((atom_vals,rest_vals),(1,RM,RN,...))

    // Unfold and return
    return v_tensor(_, append<R>(Int<0>{},_));
  }

  CUTE_HOST_DEVICE constexpr static
  auto
  get_layoutS_MN()
  {
    // (M,N) -> (M,N)
    auto ref_S = make_layout(TiledShape_MN{});
    // (thr_idx,val_idx) -> (M,N)
    auto layoutS_TV = tidfrg_S(ref_S);
    // (M,N) -> (thr_idx,val_idx)
    auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(ref_S));

    // athrid = (v,m,k) -> thr_idx
    auto thrID_S = make_layout(size<0>(TiledLayout_TV{}));

    return cute::make_tuple(layoutS_MK, thrID_S);
  }

  CUTE_HOST_DEVICE constexpr static
  auto
  get_layoutS_TV()
  {
    // (M,N) -> (M,N)
    auto ref_S = make_layout(TiledShape_MN{});
    // (thr_idx,val_idx) -> (M,N)
    return tidfrg_S(ref_S)(_,_,Int<0>{});
  }

  CUTE_HOST_DEVICE constexpr static
  auto
  get_layoutD_MN()
  {
    // (M,N) -> (M,N)
    auto ref_D = make_layout(TiledShape_MN{});
    // (thr_idx,val_idx) -> (M,N)
    auto layoutD_TV = tidfrg_D(ref_D);
    // (M,N) -> (thr_idx,val_idx)
    auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(ref_D));

    // athrid = (v,m,k) -> thr_idx
    auto thrID_D = make_layout(size<0>(TiledLayout_TV{}));

    return cute::make_tuple(layoutD_MK, thrID_D);
  }

  CUTE_HOST_DEVICE constexpr static
  auto
  get_layoutD_TV()
  {
    // (M,N) -> (M,N)
    auto ref_D = make_layout(TiledShape_MN{});
    // (thr_idx,val_idx) -> (M,N)
    return tidfrg_D(ref_D)(_,_,Int<0>{});
  }

  template <class ThrIdx>
  struct ThrCopy : Copy_Atom
  {
    ThrIdx thr_idx_;

    CUTE_HOST_DEVICE
    ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {}

    template <class STensor>
    CUTE_HOST_DEVICE
    auto
    partition_S(STensor&& stensor) {
      //static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename Copy_Atom::ValType),
      //              "Expected ValType for tiling SrcTensor.");
      auto thr_tensor = make_tensor(std::forward<STensor>(stensor).data(), tidfrg_S(stensor.layout()));
      return thr_tensor(thr_idx_, _, repeat<rank_v<STensor>>(_));
    }

    template <class DTensor>
    CUTE_HOST_DEVICE
    auto
    partition_D(DTensor&& dtensor) {
      //static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename Copy_Atom::ValType),
      //              "Expected ValType for tiling DstTensor.");
      auto thr_tensor = make_tensor(std::forward<DTensor>(dtensor).data(), tidfrg_D(dtensor.layout()));
      return thr_tensor(thr_idx_, _, repeat<rank_v<DTensor>>(_));
    }

    template <class STensor>
    CUTE_HOST_DEVICE static
    auto
    retile_S(STensor&& stensor) {
      static_assert(sizeof(typename remove_cvref_t<STensor>::value_type) == sizeof(typename Copy_Atom::ValType),
                    "Expected ValType for tiling SrcTensor.");
      return make_tensor(std::forward<STensor>(stensor).data(), TiledCopy::retile(stensor.layout()));
    }

    template <class DTensor>
    CUTE_HOST_DEVICE static
    auto
    retile_D(DTensor&& dtensor) {
      static_assert(sizeof(typename remove_cvref_t<DTensor>::value_type) == sizeof(typename Copy_Atom::ValType),
                    "Expected ValType for tiling DstTensor.");
      return make_tensor(std::forward<DTensor>(dtensor).data(), TiledCopy::retile(dtensor.layout()));
    }
  };

  template <class ThrIdx,
            __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
  CUTE_HOST_DEVICE static
  auto
  get_slice(ThrIdx const& thr_idx)
  {
    return ThrCopy<ThrIdx>(thr_idx);
  }

  template <class ThrIdx,
            __CUTE_REQUIRES(is_integral<ThrIdx>::value)>
  CUTE_HOST_DEVICE static
  auto
  get_thread_slice(ThrIdx const& thr_idx)
  {
    return get_slice(thr_idx);
  }
};
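// Kernel-side usage sketch (assumes 128 participating threads and caller-defined half_t
// tensors gA in gmem and sA in smem; the variable names are hypothetical):
//
//   auto tiled_copy = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, half_t>{},
//                                     Layout<Shape<_16,_8>>{},   // (m,n) -> thr_idx, 128 threads
//                                     Layout<Shape< _1,_8>>{});  // 8 half_t per thread
//   auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
//   Tensor tAgA = thr_copy.partition_S(gA);   // (CPY,CPY_M,CPY_N) source fragments
//   Tensor tAsA = thr_copy.partition_D(sA);   // (CPY,CPY_M,CPY_N) destination fragments
//   copy(tiled_copy, tAgA, tAsA);             // one vectorized access per fragment element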
template <class... Args,
          class LayoutCopy_TV,
          class Tile>
CUTE_HOST_DEVICE
auto
make_tiled_copy_impl(Copy_Atom<Args...> const& atom,
                     LayoutCopy_TV      const&,
                     Tile               const&)
{
  return TiledCopy<Copy_Atom<Args...>, LayoutCopy_TV, Tile>{atom};
}

//
// These tile the Copy_Atom as a whole
//

template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_A(Copy_Atom<CArgs...> const& copy_atom,
                  TiledMMA<MArgs...>  const& tiled_mma)
{
  using MNK = typename TiledMMA<MArgs...>::TiledShape_MNK;
  return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), make_shape(size<0>(MNK{}), size<2>(MNK{})));
}

template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_B(Copy_Atom<CArgs...> const& copy_atom,
                  TiledMMA<MArgs...>  const& tiled_mma)
{
  using MNK = typename TiledMMA<MArgs...>::TiledShape_MNK;
  return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), make_shape(size<1>(MNK{}), size<2>(MNK{})));
}

template <class... CArgs, class... MArgs>
CUTE_HOST_DEVICE
auto
make_tiled_copy_C(Copy_Atom<CArgs...> const& copy_atom,
                  TiledMMA<MArgs...>  const& tiled_mma)
{
  using MNK = typename TiledMMA<MArgs...>::TiledShape_MNK;
  return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), make_shape(size<0>(MNK{}), size<1>(MNK{})));
}
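// Sketch of matching a copy to an MMA partitioning (assumes the SM80 f16 MMA op from
// cute/arch/mma_sm80.hpp; tiled_mma and s2r_copy_A are hypothetical names):
//
//   auto tiled_mma  = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{}, Layout<Shape<_2,_2,_1>>{});
//   auto s2r_copy_A = make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, half_t>{}, tiled_mma);
//
// The resulting TiledCopy partitions its tensors exactly like tiled_mma partitions its A operand
// over (M,K), so ldmatrix fragments land where the MMA expects them.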
template <class... Args,
          class ThrLayout,
          class ValLayout = Layout<_1>>
CUTE_HOST_DEVICE
auto
make_tiled_copy(Copy_Atom<Args...> const& copy_atom,
                ThrLayout          const& thr_layout = {},     // (m,n) -> thr_idx
                ValLayout          const& val_layout = {})
{
  constexpr int R = cute::max(rank_v<ThrLayout>, rank_v<ValLayout>);

  auto thr_layout_mn = append<R>(thr_layout, Layout<_1>{});
  auto val_layout_mn = append<R>(val_layout, Layout<_1>{});

  // Take the raked_products to compute the Layout_MN
  auto layout_mn = raked_product(thr_layout_mn, val_layout_mn);
  auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout)));

  //print("thr_layout: "); print(thr_layout_mn); print("\n");
  //print("val_layout: "); print(val_layout_mn); print("\n");
  //print("layout_mn : "); print(layout_mn);     print("\n");
  //print("layout_tv : "); print(layout_tv);     print("\n");

  return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn)));
}

// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy
template <class... Args,
          class TiledCopy>
CUTE_HOST_DEVICE
auto
make_tiled_copy_S(Copy_Atom<Args...> const& copy_atom,
                  TiledCopy          const& tiled_copy)
{
  return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{});
}

// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy
template <class... Args,
          class TiledCopy>
CUTE_HOST_DEVICE
auto
make_tiled_copy_D(Copy_Atom<Args...> const& copy_atom,
                  TiledCopy          const& tiled_copy)
{
  return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{});
}

//
// Size
//

// The logical size of a TiledCopy
template <class... Args>
CUTE_HOST_DEVICE constexpr
auto
tile_size(TiledCopy<Args...> const&)
{
  return size(typename TiledCopy<Args...>::TiledShape_MN{});
}

// The number of threads involved in a TiledCopy
template <class... Args>
CUTE_HOST_DEVICE constexpr
auto
size(TiledCopy<Args...> const&)
{
  return typename TiledCopy<Args...>::TiledNumThr{};
}

//
// Display utilities
//

template <class... Args>
CUTE_HOST_DEVICE
auto
print_latex(TiledCopy<Args...> const& copy)
{
  auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN();
  auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN();

  print_latex_copy(layoutS_MN, thrID_S,
                   layoutD_MN, thrID_D);
}

// MNK Copy Layout to Latex TIKZ -- 8-value color coded by thread
template <class LayoutS, class ThrIDS,
          class LayoutD, class ThrIDD>
CUTE_HOST_DEVICE
void
print_latex_copy(LayoutS const& S, ThrIDS const& TS,  // (m,n) -> (tid,vid)  and  tid -> thr_idx
                 LayoutD const& D, ThrIDD const& TD)  // (m,n) -> (tid,vid)  and  tid -> thr_idx
{
  CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{});
  CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{});
  assert(size<0>(S) == size<0>(D));
  assert(size<1>(S) == size<1>(D));

  char const* latex_header =
      "\\documentclass{standalone}\n"
      "\\usepackage{tikz}\n"
      "\\usetikzlibrary{external}\n"
      "\\tikzexternalize\n"
      "\\begin{document}\n"
      "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},box/.style={rectangle,draw=black,thick,minimum size=1cm,anchor=center}]\n\n";
  char const* latex_footer =
      "\\end{tikzpicture}\n"
      "\\end{document}\n";

  char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}",
                              "{rgb,255:red,175;green,255;blue,175}",
                              "{rgb,255:red,255;green,255;blue,175}",
                              "{rgb,255:red,255;green,175;blue,175}",
                              "{rgb,255:red,210;green,210;blue,255}",
                              "{rgb,255:red,210;green,255;blue,210}",
                              "{rgb,255:red,255;green,255;blue,210}",
                              "{rgb,255:red,255;green,210;blue,210}"};

  // Header
  printf("%% LayoutS: "); print(S);  printf("\n");
  printf("%% ThrIDS : "); print(TS); printf("\n");
  printf("%% LayoutD: "); print(D);  printf("\n");
  printf("%% ThrIDD : "); print(TD); printf("\n\n");

  printf(latex_header);

  // S starting at 0,0
  for (int i = 0; i < size<0>(S); ++i) {
    for (int j = 0; j < size<1>(S); ++j) {
      int thrid   = S(i,j) % size(TS);
      int val_idx = S(i,j) / size(TS);
      int thr_idx = TS(thrid);

      printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
             color_map[thr_idx % 8],
             i, j,
             thr_idx, val_idx);
    }
  }

  // D starting at 0,size<1>(S)+3
  for (int i = 0; i < size<0>(D); ++i) {
    for (int j = 0; j < size<1>(D); ++j) {
      int thrid   = D(i,j) % size(TD);
      int val_idx = D(i,j) / size(TD);
      int thr_idx = TD(thrid);

      printf("\\node[box,fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n",
             color_map[thr_idx % 8],
             i, j + size<1>(S) + 3,
             thr_idx, val_idx);
    }
  }

  // S Labels
  for (int i = 0, j = -1; i < size<0>(S); ++i) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i);
  }
  for (int j = 0, i = -1; j < size<1>(S); ++j) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j);
  }
  // D Labels
  for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i);
  }
  for (int j = 0, i = -1; j < size<1>(D); ++j) {
    printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j);
  }

  // Footer
  printf(latex_footer);
}

} // end namespace cute

////////////////////////////////////////////////////////////////////////////////////////////////////

#include <cute/atom/copy_traits_sm75.hpp>
#include <cute/atom/copy_traits_sm80.hpp>
#include <cute/atom/copy_traits_sm90.hpp>

// Config
#if (__CUDACC_VER_MAJOR__ >= 12)
#  define CUTE_COPY_ATOM_TMA_SM90_ENABLED
#endif

#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)
#include <cute/atom/copy_traits_sm90_tma.hpp>
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////