/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

// Default GEMM configurations expressed as CUTLASS 3.x collective types.
//
// DefaultGemmConfigurationToCutlass3Types<OperatorClass, ArchTag, ElementA, LayoutA, ElementB,
// LayoutB, ElementC, LayoutC, ElementAccumulator> is keyed on the same parameters a 2.x default
// configuration uses. Each specialization publishes a CollectiveMainloop
// (collective::CollectiveMma) and a CollectiveEpilogue (epilogue::collective::DefaultEpilogue)
// assembled from its own tile shape, TiledMMA, and per-operand copy/layout atoms.
//
// NOTE(review): in this copy of the file, many angle-bracketed template parameter and argument
// lists appear to be missing. Examples: `template struct DefaultGemm_TensorOpSm80_OperandA;`,
// `using SmemCopyAtom = Copy_Atom;`, `TagToStrideA_t,`, `Layout, Stride<_64, _1>>{}`.
// As written, those lines are not valid C++. This looks like a text-extraction artifact.
// Restore the lists from the upstream CUTLASS source before building; the code tokens below
// are reproduced exactly as found.

#pragma once

#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_atom.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/layout/layout.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_mma.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"

namespace cutlass {
namespace gemm {
namespace device {
using namespace cute;

// This type is only intended to demonstrate porting 2.x kernels to 3.0.
//
// Primary template: used only when no specialization below matches the requested
// (OperatorClass, ArchTag, element, layout) combination.
// `sizeof(ElementA) == 0` can never be true, but it depends on a template parameter, so the
// static_assert fires only when this fallback is actually instantiated. The effect is to turn
// an unsupported configuration into a readable compile-time error.
template<
  class OperatorClass, class ArchTag,
  class ElementA, class LayoutA,
  class ElementB, class LayoutB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types {
  static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists.");
};

///////////////////////////////////////////////////////////////////////////////

namespace detail {

// Per-operand building blocks for the SM80 tensor-op configurations.
// Each specialization supplies three aliases that the configurations forward to CollectiveMma:
//   SmemLayoutAtom - shared-memory layout atom (a Swizzle composed with a CuTe Layout)
//   SmemCopyAtom   - copy atom for the shared-memory side of the operand
//   GmemTiledCopy  - make_tiled_copy(copy atom, thread layout, value layout) for the
//                    global-memory side
// The configurations below instantiate these as
// `DefaultGemm_TensorOpSm80_OperandA< Element, Layout, Alignment, 32>`.
// NOTE(review): that suggests the parameter list (Element, Layout, Alignment, SizeK), which is
// missing in this copy — confirm against upstream.
template
struct DefaultGemm_TensorOpSm80_OperandA;

template
struct DefaultGemm_TensorOpSm80_OperandB;

//
// F16: 128-by-128-by-64
//

/// Operand A - Row-major (K-Major)
// Smem atom uses Stride<_64, _1>: unit stride along the second (K) mode, 64 between rows.
template <>
struct DefaultGemm_TensorOpSm80_OperandA
{
  // Smem
  using SmemLayoutAtom = decltype(
    composition(Swizzle<3,3,3>{},
                Layout,
                       Stride<_64, _1>>{}));
  using SmemCopyAtom = Copy_Atom;

  // Gmem
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, half_t>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));
};

/// Operand A - Column-major (M-major)
// Smem atom uses Stride< _1,_64>: unit stride along the first (M) mode.
template
struct DefaultGemm_TensorOpSm80_OperandA
{
  // Smem
  using SmemLayoutAtom = decltype(
    composition(Swizzle<3,3,3>{},
                Layout,
                       Stride< _1,_64>>{}));
  using SmemCopyAtom = Copy_Atom;

  // Gmem
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, half_t>{},
                    Layout,
                           Stride< _1,_16>>{},
                    Layout>{}));
};

// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands.
// B is addressed as (N,K), so a column-major B is K-major (like row-major A), and a row-major
// B is N-major (like column-major A). Each OperandB specialization inherits the matching
// OperandA definition.

// Operand B - Column-Major (K-major)
template
struct DefaultGemm_TensorOpSm80_OperandB
     : DefaultGemm_TensorOpSm80_OperandA
{};

// Operand B - Row-Major (N-major)
template
struct DefaultGemm_TensorOpSm80_OperandB
     : DefaultGemm_TensorOpSm80_OperandA
{};

//
// F16: 128-by-128-by-32 (small k-block)
//

/// Operand A - Row-major (K-Major)
// Same structure as the 64-wide K-major atom above, with a narrower swizzle (Swizzle<2,3,3>)
// and a row stride of 32.
template <>
struct DefaultGemm_TensorOpSm80_OperandA
{
  // Smem
  using SmemLayoutAtom = decltype(
    composition(Swizzle<2,3,3>{},
                Layout,
                       Stride<_32, _1>>{}));
  using SmemCopyAtom = Copy_Atom;

  // Gmem
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, half_t>{},
                    Layout,
                           Stride< _4,_1>>{},
                    Layout>{}));
};

}

///////////////////////////////////////////////////////////////////////////////

// Ampere MMA F32F16
// - A/B are half_t; C and the accumulator are float. LayoutA/LayoutB/LayoutC are left generic;
//   the operand atoms are chosen by the detail::DefaultGemm_TensorOpSm80_Operand{A,B} helpers
//   keyed on the layout.
// - 128x128x32 CTA tile, 128 threads, MainloopSm80CpAsync<3> dispatch policy.
// - SizeK = 32 is passed to the operand helpers to match TileShape's K extent.
template
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    half_t, LayoutA,
    half_t, LayoutB,
    float, LayoutC,
    float>
{
  using TileShape = Shape<_128, _128, _32>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>,  // 2x2x1 thread group
    Layout>>; // 1x2x1 value group for 16x16x16 MMA and LDSM

  // A
  static constexpr int kAlignmentA = 8;
  using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA<
    half_t, LayoutA, kAlignmentA, 32>;
  using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
  using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
  using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;

  // B
  static constexpr int kAlignmentB = 8;
  using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB<
    half_t, LayoutB, kAlignmentB, 32>;
  using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
  using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
  using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    half_t, TagToStrideA_t,
    half_t, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

namespace detail {

//
// TF32: 128-by-128-by-kblock (kBlock = 16, 32)
//

/// Operand A - Row-major (K-major) (kBlock = 32)
template <>
struct DefaultGemm_TensorOpSm80_OperandA
{
  // Smem
  using SmemLayoutAtom = decltype(
    composition(Swizzle<3,2,3>{},
                Layout,
                       Stride<_32, _1>>{}));
  using SmemCopyAtom = Copy_Atom;

  // Gmem
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, tfloat32_t>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));
};

/// Operand A - Row-major (K-major) (kBlock = 16)
template <>
struct DefaultGemm_TensorOpSm80_OperandA
{
  // Smem
  using SmemLayoutAtom = decltype(
    composition(Swizzle<2,2,3>{},
                Layout,
                       Stride<_16, _1>>{}));
  using SmemCopyAtom = Copy_Atom;

  // Gmem
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, tfloat32_t>{},
                    Layout,
                           Stride< _4,_1>>{},
                    Layout>{}));
};

/// Operand A - Column-major (M-major)
template
struct DefaultGemm_TensorOpSm80_OperandA
{
  // Smem
  using SmemLayoutAtom = decltype(
    composition(Swizzle<3,2,3>{},
                Layout,
                       Stride< _1,_32>>{}));
  using SmemCopyAtom = Copy_Atom, tfloat32_t>;

  // Gmem
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, tfloat32_t>{},
                    Layout,
                           Stride< _1,_16>>{},
                    Layout>{}));
};

// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands.
// As in the F16 case, a K-major B reuses the K-major A atoms and an N-major B reuses the
// M-major A atoms.

// Operand B - Column-Major (K-major)
template
struct DefaultGemm_TensorOpSm80_OperandB
     : DefaultGemm_TensorOpSm80_OperandA
{};

// Operand B - Row-Major (N-major)
template
struct DefaultGemm_TensorOpSm80_OperandB
     : DefaultGemm_TensorOpSm80_OperandA
{};

}

///////////////////////////////////////////////////////////////////////////////

// Ampere MMA F32TF32
// - A/B are tfloat32_t; C and the accumulator are float. Layouts are generic and resolved
//   through the detail operand helpers.
// - 128x128x32 CTA tile, 128 threads, MainloopSm80CpAsync<3>, alignment of 4 elements per
//   operand.
template
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    tfloat32_t, LayoutA,
    tfloat32_t, LayoutB,
    float, LayoutC,
    float>
{
  using TileShape = Shape<_128, _128, _32>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group
    Layout>>;                    // 1x2x1 value group for 16x16x8 and LDSM

  // A
  static constexpr int kAlignmentA = 4;
  using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA<
    tfloat32_t, LayoutA, kAlignmentA, 32>;
  using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
  using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
  using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;

  // B
  static constexpr int kAlignmentB = 4;
  using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB<
    tfloat32_t, LayoutB, kAlignmentB, 32>;
  using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
  using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
  using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    tfloat32_t, TagToStrideA_t,
    tfloat32_t, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// Ampere tensor-op int8 TN
// - Only RowMajor A / ColumnMajor B (both K-major); C and the accumulator are int32_t.
// - 128x128x64 CTA tile, 128 threads, MainloopSm80CpAsync<3>, alignment of 16 elements per
//   operand.
// - Operand atoms are spelled out inline rather than via the detail helpers.
template
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    int8_t, cutlass::layout::RowMajor,
    int8_t, cutlass::layout::ColumnMajor,
    int32_t, LayoutC,
    int32_t>
{
  using TileShape = Shape<_128, _128, _64>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>,  // 2x2x1 thread group
    Layout>>; // 1x2x1 value group for 16x16x32 and LDSM

  // A (M,K) K-major
  using SmemLayoutAtomA = decltype(
    composition(
      Swizzle<2,4,3>{},
      Layout,
             Stride<_64, _1>>{}));
  static constexpr int kAlignmentA = 16;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, int8_t>{},
                    Layout,
                           Stride< _4,_1>>{},
                    Layout>>{}));
  // LDS.32- or LDSM-based copy atom
  // using SmemCopyAtomA = Copy_Atom;
  using SmemCopyAtomA = Copy_Atom; // LDSM works

  // B (N,K) K-major
  using SmemLayoutAtomB = decltype(
    composition(
      Swizzle<2,4,3>{},
      Layout,
             Stride<_64, _1>>{}));
  static constexpr int kAlignmentB = 16;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, int8_t>{},
                    Layout,
                           Stride< _4,_1>>{},
                    Layout>>{}));

  // LDS.32- or LDSM-based copy atom
  // using SmemCopyAtomB = Copy_Atom;
  using SmemCopyAtomB = Copy_Atom; // LDSM works

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    int8_t, TagToStrideA_t,
    int8_t, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////
//////////////////////////// SIMT TWO STAGE ///////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

namespace detail {

// Per-operand building blocks for the generic SIMT two-stage configuration. Each
// specialization provides SmemLayoutAtom, SmemCopyAtom and GmemTiledCopy.
// NOTE(review): the template parameter lists and the layout arguments that distinguish the two
// OperandA specializations are missing in this copy, so which specialization serves which
// layout cannot be read off here — confirm against upstream.
template
struct DefaultGemm_Simt_OperandA;

///////////////////////////////////////////////////////////////////////////////

// Unpadded smem atom with unit stride in mode 0 (Stride< _1,_128>).
// Global-memory threads are also laid out with unit stride in mode 0 (Stride< _1,_32>).
template
struct DefaultGemm_Simt_OperandA
{
  using SmemLayoutAtom = Layout,
                                Stride< _1,_128>>;

  using SmemCopyAtom = Copy_Atom;

  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, Element>{},
                    Layout,
                           Stride< _1,_32>>{},
                    Layout>{}));
};

// Smem atom with unit stride in mode 0 and mode-1 stride padded to 128 + 4.
// Global-memory threads are laid out with unit stride in mode 1 (Stride< _8, _1>).
template
struct DefaultGemm_Simt_OperandA
{
  using SmemLayoutAtom = Layout,
                                Stride< _1,Int<128 + 4>>>; // Padded

  using SmemCopyAtom = Copy_Atom;

  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom, Element>{},
                    Layout,
                           Stride< _8, _1>>{},
                    Layout>{}));
};

// OperandB is defined entirely in terms of OperandA (see the inheritance below).
template
struct DefaultGemm_Simt_OperandB;

template
struct DefaultGemm_Simt_OperandB
  : DefaultGemm_Simt_OperandA {};

template
struct DefaultGemm_Simt_OperandB
  : DefaultGemm_Simt_OperandA {};

} // end namespace detail

// SIMT Two Stage
// - Fully generic over ArchTag, element types and A/B/C layouts for arch::OpClassSimt.
// - 128x128x8 CTA tile, 256 threads, MainloopSm70TwoStage, alignment of 1 element per operand.
// - Operand atoms come from detail::DefaultGemm_Simt_Operand{A,B}.
// More specific SIMT int8 (idp4a) and Sm80 multi-stage specializations follow.
template <
  class ArchTag,
  class ElementA, class LayoutA,
  class ElementB, class LayoutB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, ArchTag,
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ElementAccumulator>
{
  using TileShape = Shape<_128, _128, _8>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm70TwoStage;
  using TiledMma = TiledMMA<
    MMA_Atom>,
    Layout>>;

  // A
  static constexpr int kAlignmentA = 1;
  using DefaultOperandA = detail::DefaultGemm_Simt_OperandA;
  using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
  using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
  using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;

  // B
  static constexpr int kAlignmentB = 1;
  using DefaultOperandB = detail::DefaultGemm_Simt_OperandB;
  using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
  using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
  using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};


//
// DP4A - int8 Proof-of-concept
//
// The four idp4a specializations below share these settings:
// - int8_t A/B with an int32_t accumulator.
// - 128x128x32 CTA tile, 256 threads, MainloopSm70TwoStage.
// - The same interleaved smem atom for both operands:
//   Stride< _4, Stride<_1,_512>> (unit stride in the inner part of mode 1).
// They differ only in the A/B layouts. A K-major operand in global memory gets alignment 4 and
// a (Stride< _8,_1>) thread layout; an M/N-major operand gets alignment 1 and (Stride< _1,_32>).
//

// SIMT Two Stage TN - idp4a
template <
  class ArchTag,
  class ElementC, class LayoutC>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, ArchTag,
    int8_t, cutlass::layout::RowMajor,
    int8_t, cutlass::layout::ColumnMajor,
    ElementC, LayoutC,
    int32_t>
{
  using TileShape = Shape<_128, _128, _32>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm70TwoStage;
  // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>>; // Tile of atoms (threads)

  // A (M,K) K-major
  using ElementA = int8_t;
  // 40% from regular M and N major layout
  // using SmemLayoutAtomA = Layout,
  //                                Stride< _1,_128>>;
  // 80% from interleaved layouts
  using SmemLayoutAtomA = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;

  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 4;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));

  // B (N,K) K-major
  using ElementB = int8_t;
  // 40% from regular M and N major layout
  // using SmemLayoutAtomB = Layout,
  //                                Stride< _1,_128>>;
  // 80% from interleaved layouts
  using SmemLayoutAtomB = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;

  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 4;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// SIMT Two Stage NN - idp4a
// A is M-major (alignment 1); B is K-major (alignment 4).
template <
  class ArchTag,
  class ElementC, class LayoutC>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, ArchTag,
    int8_t, cutlass::layout::ColumnMajor,
    int8_t, cutlass::layout::ColumnMajor,
    ElementC, LayoutC,
    int32_t>
{
  using TileShape = Shape<_128, _128, _32>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm70TwoStage;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>>;

  // A (M,K) M-major
  using ElementA = int8_t;
  using SmemLayoutAtomA = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 1;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout,
                           Stride< _1,_32>>{},
                    Layout>{}));

  // B (N,K) K-major
  using ElementB = int8_t;
  using SmemLayoutAtomB = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 4;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// SIMT Two Stage NT - idp4a
// A is M-major and B is N-major; both use alignment 1.
template <
  class ArchTag,
  class ElementC, class LayoutC>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, ArchTag,
    int8_t, cutlass::layout::ColumnMajor,
    int8_t, cutlass::layout::RowMajor,
    ElementC, LayoutC,
    int32_t>
{
  using TileShape = Shape<_128, _128, _32>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm70TwoStage;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>>;

  // A (M,K) M-major
  using ElementA = int8_t;
  using SmemLayoutAtomA = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 1;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout,
                           Stride< _1,_32>>{},
                    Layout>{}));

  // B (N,K) N-major
  using ElementB = int8_t;
  using SmemLayoutAtomB = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 1;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout,
                           Stride< _1,_32>>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// SIMT Two Stage TT - idp4a
// A is K-major (alignment 4); B is N-major (alignment 1).
template <
  class ArchTag,
  class ElementC, class LayoutC>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, ArchTag,
    int8_t, cutlass::layout::RowMajor,
    int8_t, cutlass::layout::RowMajor,
    ElementC, LayoutC,
    int32_t>
{
  using TileShape = Shape<_128, _128, _32>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm70TwoStage;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>>;

  // A (M,K) K-major
  using ElementA = int8_t;
  using SmemLayoutAtomA = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 4;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));

  // B (N,K) N-major
  using ElementB = int8_t;
  using SmemLayoutAtomB = Layout>,
                                 Stride< _4, Stride<_1,_512>>>;
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 1;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout,
                           Stride< _1,_32>>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////
/////////////////////////// SIMT MULTI STAGE //////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

// The four Sm80 SIMT multi-stage specializations share these settings:
// - Generic element types.
// - 128x128x16 CTA tile, 256 threads, MainloopSm80CpAsync<3>.
// Operand handling depends on the major-ness:
// - M/N-major operands use alignment 2 and define an AlignmentType alias from
//   cute::uint_byte_t sized by sizeof(Element) * kAlignment.
// - K-major operands use alignment 1 and a smem atom whose mode-1 stride is padded to 128 + 1.

// SIMT Multi Stage NT
template <
  class ElementA,
  class ElementB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, arch::Sm80,
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, LayoutC,
    ElementAccumulator>
{
  using TileShape = Shape<_128, _128, _16>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom>,
    Layout>,
    Layout>,
    Tile,Layout<_2,_16>,Underscore>>;

  // A (M,K) M-major
  using SmemLayoutAtomA = Layout>;
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 2;
  using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout>{},
                    Layout>{}));

  // B (N,K) N-major
  using SmemLayoutAtomB = Layout>;
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 2;
  using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// SIMT Multi Stage TN
// Both operands are K-major; their global-memory tiled copies pass only a thread layout
// (no value layout argument).
template <
  class ElementA,
  class ElementB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, arch::Sm80,
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, LayoutC,
    ElementAccumulator>
{
  using TileShape = Shape<_128, _128, _16>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom>,
    Layout>>;

  // A (M,K) K-major
  using SmemLayoutAtomA = Layout,
                                 Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 1;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout,
                           Stride<_16, _1>>{}));

  // B (N,K) K-major
  using SmemLayoutAtomB = Layout,
                                 Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 1;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout,
                           Stride<_16, _1>>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// SIMT Multi Stage NN
// A is M-major (alignment 2); B is K-major (alignment 1, padded smem atom).
template <
  class ElementA,
  class ElementB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, arch::Sm80,
    ElementA, cutlass::layout::ColumnMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementC, LayoutC,
    ElementAccumulator>
{
  using TileShape = Shape<_128, _128, _16>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom>,
    Layout>,
    Layout>,
    Tile,Underscore,Underscore>>;

  // A (M,K) M-major
  using SmemLayoutAtomA = Layout>;
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 2;
  using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout>{},
                    Layout>{}));

  // B (N,K) K-major
  using SmemLayoutAtomB = Layout,
                                 Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 1;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout,
                           Stride<_16, _1>>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// SIMT Multi Stage TT
// A is K-major (alignment 1, padded smem atom); B is N-major (alignment 2).
template <
  class ElementA,
  class ElementB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassSimt, arch::Sm80,
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::RowMajor,
    ElementC, LayoutC,
    ElementAccumulator>
{
  using TileShape = Shape<_128, _128, _16>;
  static constexpr int ThreadCount = 256;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom>,
    Layout>,
    Layout>,
    Tile,Underscore>>;

  // A (M,K) K-major
  using SmemLayoutAtomA = Layout,
                                 Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 1;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, ElementA>{},
                    Layout,
                           Stride<_16, _1>>{}));

  // B (N,K) N-major
  using SmemLayoutAtomB = Layout>;
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 2;
  using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, ElementB>{},
                    Layout>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    ElementA, TagToStrideA_t,
    ElementB, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// Ampere fp64 MMA TN (K-Major A and K-Major B)
// - Fully specialized: double A/B/C/accumulator, RowMajor A, ColumnMajor B, ColumnMajor C.
// - 128x64x16 CTA tile, 128 threads, MainloopSm80CpAsync<3>, alignment of 1 element per operand.
// - Both smem atoms compose SwizzleXor<2,0,2> with a Stride<_1, _4> layout, i.e. unit stride in
//   mode 0, even though both operands are K-major in global memory.
template <>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::ColumnMajor,
    double>
{
  using TileShape = Shape<_128, _64, _16>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,                               // Atom
    Layout>,                                // Atom layout
    Layout>,                                // Val layout
    Tile,Layout<_2,_16>,Underscore>>;       // Mode permutations

  // A (M,K) K-Major
  using SmemLayoutAtomA = decltype(
    composition(SwizzleXor<2,0,2>{},
                Layout,
                       Stride<_1, _4>>{})); // M, K
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 1;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, double>{},   // CopyAtom
                    Layout,
                           Stride<_16, _1>>{}, // ThrLayout for CopyAtom
                    Layout>{}));            // Value layout: 1x1 doubles

  // B (N,K) K-Major
  using SmemLayoutAtomB = decltype(
    composition(SwizzleXor<2,0,2>{},
                Layout,
                       Stride<_1, _4>>{})); // N, K
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 1;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, double>{},   // CopyAtom
                    Layout,
                           Stride<_16, _1>>{}, // ThrLayout for CopyAtom
                    Layout>{}));            // Value layout: 1x1 doubles

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    double, TagToStrideA_t,
    double, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;

  // Alternative smem-staged epilogue, kept only as a comment (not compiled):
/*
  using EpilogueOutputOp = epilogue::collective::Epilogue<
    epilogue::thread::LinearCombination,
    Layout,
           Stride< _1,_64>>,                                  // SMEM layout
    Copy_Atom,double>,                                        // R2S with tiled_mma layout
    decltype(make_tiled_copy(Copy_Atom,double>{},             // S2R
                             Layout,
                                    Stride< _1,_16>>{},       // Thread layout
                             Layout>{})),                     // Value layout
    Copy_Atom,double>                                         // R2G with S2R_dst layout
  >;
*/
};
///////////////////////////////////////////////////////////////////////////////

// Ampere fp64 MMA NN (M-Major A and K-Major B)
//
// Full specialization: OpClassTensorOp on SM80, every element type double,
// A column-major, B column-major, C column-major.
//
// NOTE(review): template argument lists in this section appear to have been
// stripped (e.g. `Copy_Atom;`, `Layout>`, `TagToStrideA_t,`), so the text as
// shown does not compile. Restore them from upstream CUTLASS before use.
template <>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::ColumnMajor,
    double>
{
  // CTA tile shape (M, N, K).
  using TileShape = Shape<_128, _64, _16>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,                                      // Atom
    Layout>,                                       // Atom layout
    Layout>,                                       // Val layout
    Tile,Layout<_2,_16>,Underscore>>;              // Mode permutations

  // A (M,K) M-Major
  using SmemLayoutAtomA = decltype(
    composition(SwizzleXor<2,2,0>{},
                Layout,
                       Stride< _1,_16>>{}));          // M, K
  using SmemCopyAtomA = Copy_Atom;
  // Two doubles per thread per copy, matching the 2x1 value layout below.
  static constexpr int kAlignmentA = 2;
  // Thread layout strides are (1, 16) over (M, K): consecutive threads step
  // along M, matching A's M-major global layout.
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, double>{},  // CopyAtom
                    Layout,
                           Stride< _1,_16>>{},     // ThrLayout for CopyAtom
                    Layout>{}));                   // Value layout: 2x1 doubles

  // B (N,K) K-Major
  using SmemLayoutAtomB = decltype(
    composition(SwizzleXor<2,0,2>{},
                Layout,
                       Stride<_1, _4>>{}));// N, K
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 1;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, double>{},  // CopyAtom
                    Layout,
                           Stride<_16, _1>>{},     // ThrLayout for CopyAtom
                    Layout>{}));                   // Value layout: 1x1 doubles

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    double, TagToStrideA_t,
    double, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// Ampere fp64 MMA NT (M-Major A and N-Major B)
//
// Full specialization: OpClassTensorOp on SM80, every element type double,
// A column-major, B row-major, C column-major.
template <>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::ColumnMajor,
    double>
{
  // CTA tile shape (M, N, K).
  using TileShape = Shape<_128, _64, _16>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,                                      // Atom
    Layout>,                                       // Atom layout
    Layout>,                                       // Val layout
    Tile,Layout<_2,_16>,Underscore>>;              // Mode permutations

  // A (M,K) M-Major
  using SmemLayoutAtomA = decltype(
    composition(SwizzleXor<2,2,0>{},
                Layout,
                       Stride< _1,_16>>{}));          // M, K
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 2;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, double>{},  // CopyAtom
                    Layout,
                           Stride< _1,_16>>{},     // ThrLayout for CopyAtom
                    Layout>{}));                   // Value layout: 2x1 doubles

  // B (N,K) N-Major
  // Same smem atom and gmem copy shape as A, since both operands are
  // MN-major here.
  using SmemLayoutAtomB = decltype(
    composition(SwizzleXor<2,2,0>{},
                Layout,
                       Stride< _1,_16>>{}));          // N, K
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 2;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, double>{},  // CopyAtom
                    Layout,
                           Stride< _1,_16>>{},     // ThrLayout for CopyAtom
                    Layout>{}));                   // Value layout: 2x1 doubles

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    double, TagToStrideA_t,
    double, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// Ampere fp64 MMA TT (K-Major A and N-Major B)
//
// Full specialization: OpClassTensorOp on SM80, every element type double,
// A row-major, B row-major, C column-major.
template <>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm80,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::ColumnMajor,
    double>
{
  // CTA tile shape (M, N, K).
  using TileShape = Shape<_128, _64, _16>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,                                      // Atom
    Layout>,                                       // Atom layout
    Layout>,                                       // Val layout
    Tile,Layout<_2,_16>,Underscore>>;              // Mode permutations

  // A (M,K) K-Major
  // Same A configuration as the TN specialization.
  using SmemLayoutAtomA = decltype(
    composition(SwizzleXor<2,0,2>{},
                Layout,
                       Stride<_1, _4>>{}));           // M, K
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 1;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, double>{},  // CopyAtom
                    Layout,
                           Stride<_16, _1>>{},     // ThrLayout for CopyAtom
                    Layout>{}));                   // Value layout: 1x1 doubles

  // B (N,K) N-Major
  // Same B configuration as the NT specialization.
  using SmemLayoutAtomB = decltype(
    composition(SwizzleXor<2,2,0>{},
                Layout,
                       Stride< _1,_16>>{}));          // N, K
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 2;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, double>{},  // CopyAtom
                    Layout,
                           Stride< _1,_16>>{},     // ThrLayout for CopyAtom
                    Layout>{}));                   // Value layout: 2x1 doubles

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    double, TagToStrideA_t,
    double, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
    TagToStrideC_t,
    TagToStrideC_t,
    epilogue::thread::LinearCombination,
    cutlass::gemm::EpilogueDefault>;
};

///////////////////////////////////////////////////////////////////////////////

// Hopper fp64 MMA TN
//
// Full specialization: OpClassTensorOp on SM90, every element type double,
// A row-major, B column-major, C column-major.
// The mainloop still uses the SM80 cp.async dispatch policy; only the
// epilogue is obtained from the Sm90 epilogue CollectiveBuilder.
template <>
struct DefaultGemmConfigurationToCutlass3Types<
    arch::OpClassTensorOp, arch::Sm90,
    double, cutlass::layout::RowMajor,
    double, cutlass::layout::ColumnMajor,
    double, cutlass::layout::ColumnMajor,
    double>
{
  // CTA tile shape (M, N, K).
  using TileShape = Shape<_128, _64, _16>;
  static constexpr int ThreadCount = 128;
  using DispatchPolicy = MainloopSm80CpAsync<3>;
  using TiledMma = TiledMMA<
    MMA_Atom,
    Layout>>;

  // A (M,K) K-major
  // The smem atom covers the full 128x16 (M,K) tile. Step<_2,_1> orders K
  // before M, so K is the stride-1 mode (K-major).
  using SmemLayoutAtomA = decltype(
    make_ordered_layout(Shape<_128,_16>{},
                        Step <  _2,  _1>{}));              // M, K
  using SmemCopyAtomA = Copy_Atom;
  static constexpr int kAlignmentA = 2;
  using GmemTiledCopyA = decltype(
    make_tiled_copy(Copy_Atom, double>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));

  // B (N,K) K-major
  // The smem atom covers the full 64x16 (N,K) tile, again with K as the
  // stride-1 mode.
  using SmemLayoutAtomB = decltype(
    make_ordered_layout(Shape<_64,_16>{},
                        Step < _2, _1>{}));                // N, K
  using SmemCopyAtomB = Copy_Atom;
  static constexpr int kAlignmentB = 2;
  using GmemTiledCopyB = decltype(
    make_tiled_copy(Copy_Atom, double>{},
                    Layout,
                           Stride< _8,_1>>{},
                    Layout>{}));

  // Mainloop
  using CollectiveMainloop = collective::CollectiveMma<
    DispatchPolicy, TileShape,
    double, TagToStrideA_t,
    double, TagToStrideB_t,
    TiledMma,
    GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
    GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
  >;

  // Epilogue
  // Built by the Sm90 epilogue CollectiveBuilder. Shape<_1,_1,_1> is presumably
  // the cluster shape. The two "double, ColumnMajor, 1" groups appear to be
  // C and D (element type, layout, alignment) -- confirm against
  // cutlass/epilogue/collective/collective_builder.hpp.
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    TileShape, Shape<_1,_1,_1>,
    cutlass::epilogue::collective::EpilogueTileAuto,
    double, double,
    double, cutlass::layout::ColumnMajor, 1,
    double, cutlass::layout::ColumnMajor, 1,
    cutlass::epilogue::collective::EpilogueScheduleAuto
  >::CollectiveOp;
};

///////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass