| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | /***************************************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | 
					
						
							|  |  |  |  * SPDX-License-Identifier: BSD-3-Clause | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Redistribution and use in source and binary forms, with or without | 
					
						
							|  |  |  |  * modification, are permitted provided that the following conditions are met: | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 1. Redistributions of source code must retain the above copyright notice, this | 
					
						
							|  |  |  |  * list of conditions and the following disclaimer. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 2. Redistributions in binary form must reproduce the above copyright notice, | 
					
						
							|  |  |  |  * this list of conditions and the following disclaimer in the documentation | 
					
						
							|  |  |  |  * and/or other materials provided with the distribution. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * 3. Neither the name of the copyright holder nor the names of its | 
					
						
							|  |  |  |  * contributors may be used to endorse or promote products derived from | 
					
						
							|  |  |  |  * this software without specific prior written permission. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | 
					
						
							|  |  |  |  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | 
					
						
							|  |  |  |  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | 
					
						
							|  |  |  |  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | 
					
						
							|  |  |  |  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | 
					
						
							|  |  |  |  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | 
					
						
							|  |  |  |  * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | 
					
						
							|  |  |  |  * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | 
					
						
							|  |  |  |  * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | 
					
						
							|  |  |  |  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  **************************************************************************************************/ | 
					
						
							|  |  |  | #pragma once
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cute/atom/mma_atom.hpp"
 | 
					
						
							|  |  |  | #include "cute/atom/copy_atom.hpp"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/cutlass.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/gemm.h"
 | 
					
						
							|  |  |  | #include "cutlass/arch/arch.h"
 | 
					
						
							|  |  |  | #include "cutlass/arch/mma.h"
 | 
					
						
							|  |  |  | #include "cutlass/layout/layout.h"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/dispatch_policy.hpp"
 | 
					
						
							|  |  |  | #include "cutlass/gemm/collective/collective_mma.hpp"
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  | #include "cutlass/epilogue/collective/collective_builder.hpp"
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/collective/default_epilogue.hpp"
 | 
					
						
							|  |  |  | #include "cutlass/epilogue/thread/linear_combination.h"
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace cutlass { | 
					
						
							|  |  |  | namespace gemm { | 
					
						
							|  |  |  | namespace device { | 
					
						
							|  |  |  | using namespace cute; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // This type is only intended to demonstrate porting 2.x kernels to 3.0
 | 
					
						
							|  |  |  | template< | 
					
						
							|  |  |  |   class OperatorClass, class ArchTag, | 
					
						
							|  |  |  |   class ElementA, class LayoutA, | 
					
						
							|  |  |  |   class ElementB, class LayoutB, | 
					
						
							|  |  |  |   class ElementC, class LayoutC, | 
					
						
							|  |  |  |   class ElementAccumulator> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types { | 
					
						
							|  |  |  |   static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace detail { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element, typename Layout, int Alignment, int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element, typename Layout, int Alignment, int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandB; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // F16: 128-by-128-by-64
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Operand A - Row-major (K-Major)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA<half_t, layout::RowMajor, 8, 64> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Smem
 | 
					
						
							|  |  |  |   using SmemLayoutAtom = decltype( | 
					
						
							|  |  |  |     composition(Swizzle<3,3,3>{}, | 
					
						
							|  |  |  |                 Layout<Shape < _8,_64>, | 
					
						
							|  |  |  |                        Stride<_64, _1>>{})); | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Gmem
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, half_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_8>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Operand A - Column-major (M-major)
 | 
					
						
							|  |  |  | template <int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA<half_t, layout::ColumnMajor, 8, SizeK> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Smem
 | 
					
						
							|  |  |  |   using SmemLayoutAtom = decltype( | 
					
						
							|  |  |  |     composition(Swizzle<3,3,3>{}, | 
					
						
							|  |  |  |                 Layout<Shape <_64, _8>, | 
					
						
							|  |  |  |                        Stride< _1,_64>>{})); | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, half_t>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Gmem
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, half_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_16>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _8, _1>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Operand B - Column-Major (K-major)
 | 
					
						
							|  |  |  | template <int Alignment, int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandB<half_t, layout::ColumnMajor, Alignment, SizeK> | 
					
						
							|  |  |  |      : DefaultGemm_TensorOpSm80_OperandA<half_t, layout::RowMajor,    Alignment, SizeK> | 
					
						
							|  |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Operand B - Row-Major (N-major)
 | 
					
						
							|  |  |  | template <int Alignment, int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandB<half_t, layout::RowMajor,    Alignment, SizeK> | 
					
						
							|  |  |  |      : DefaultGemm_TensorOpSm80_OperandA<half_t, layout::ColumnMajor, Alignment, SizeK> | 
					
						
							|  |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // F16: 128-by-128-by-32 (small k-block)
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Operand A - Row-major (K-Major)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA<half_t, layout::RowMajor, 8, 32> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Smem
 | 
					
						
							|  |  |  |   using SmemLayoutAtom = decltype( | 
					
						
							|  |  |  |     composition(Swizzle<2,3,3>{}, | 
					
						
							|  |  |  |                 Layout<Shape < _8,_32>, | 
					
						
							|  |  |  |                        Stride<_32, _1>>{})); | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Gmem
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, half_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_4>, | 
					
						
							|  |  |  |                            Stride< _4,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_8>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Ampere MMA F32F16
 | 
					
						
							|  |  |  | template <typename LayoutA, typename LayoutB, typename LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     half_t, LayoutA, | 
					
						
							|  |  |  |     half_t, LayoutB, | 
					
						
							|  |  |  |     float, LayoutC, | 
					
						
							|  |  |  |     float> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _32>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>, | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,  // 2x2x1 thread group
 | 
					
						
							|  |  |  |       Layout<Shape<_1,_2,_1>>>; // 1x2x1 value group for 16x16x16 MMA and LDSM
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A
 | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 8; | 
					
						
							|  |  |  |   using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< | 
					
						
							|  |  |  |     half_t, LayoutA, kAlignmentA, 32>; | 
					
						
							|  |  |  |   using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; | 
					
						
							|  |  |  |   using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B
 | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 8; | 
					
						
							|  |  |  |   using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< | 
					
						
							|  |  |  |     half_t, LayoutB, kAlignmentB, 32>; | 
					
						
							|  |  |  |   using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; | 
					
						
							|  |  |  |   using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     half_t, TagToStrideA_t<LayoutA>, | 
					
						
							|  |  |  |     half_t, TagToStrideB_t<LayoutB>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<float, 1, float, float>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace detail { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // TF32: 128-by-128-by-kblock (kBlock = 16, 32)
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Operand A - Row-major  (K-major) (kBlock = 32)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::RowMajor, 4, 32> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Smem
 | 
					
						
							|  |  |  |   using SmemLayoutAtom = decltype( | 
					
						
							|  |  |  |     composition(Swizzle<3,2,3>{}, | 
					
						
							|  |  |  |                 Layout<Shape < _8,_32>, | 
					
						
							|  |  |  |                        Stride<_32, _1>>{})); | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, tfloat32_t>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Gmem
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, tfloat32_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_4>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Operand A - Row-major  (K-major) (kBlock = 16)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::RowMajor, 4, 16> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Smem
 | 
					
						
							|  |  |  |   using SmemLayoutAtom = decltype( | 
					
						
							|  |  |  |     composition(Swizzle<2,2,3>{}, | 
					
						
							|  |  |  |                 Layout<Shape < _8,_16>, | 
					
						
							|  |  |  |                        Stride<_16, _1>>{})); | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, tfloat32_t>; | 
					
						
							|  |  |  |   // Gmem
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, tfloat32_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_4>, | 
					
						
							|  |  |  |                            Stride< _4,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_4>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | /// Operand A - Column-major  (M-major)
 | 
					
						
							|  |  |  | template <int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::ColumnMajor, 4, SizeK> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   // Smem
 | 
					
						
							|  |  |  |   using SmemLayoutAtom = decltype( | 
					
						
							|  |  |  |     composition(Swizzle<3,2,3>{}, | 
					
						
							|  |  |  |                 Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                        Stride< _1,_32>>{})); | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<UniversalCopy<tfloat32_t>, tfloat32_t>; | 
					
						
							|  |  |  |   // Gmem
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, tfloat32_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_16>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _4, _1>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Operand B - Column-Major  (K-major)
 | 
					
						
							|  |  |  | template <int Alignment, int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandB<tfloat32_t, layout::ColumnMajor, Alignment, SizeK> | 
					
						
							|  |  |  |      : DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::RowMajor,    Alignment, SizeK> | 
					
						
							|  |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Operand B - Row-Major  (N-major)
 | 
					
						
							|  |  |  | template <int Alignment, int SizeK> | 
					
						
							|  |  |  | struct DefaultGemm_TensorOpSm80_OperandB<tfloat32_t, layout::RowMajor,    Alignment, SizeK> | 
					
						
							|  |  |  |      : DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::ColumnMajor, Alignment, SizeK> | 
					
						
							|  |  |  | {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Ampere MMA F32TF32
 | 
					
						
							|  |  |  | template <typename LayoutA, typename LayoutB, typename LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     tfloat32_t, LayoutA, | 
					
						
							|  |  |  |     tfloat32_t, LayoutB, | 
					
						
							|  |  |  |     float, LayoutC, | 
					
						
							|  |  |  |     float> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _32>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>, | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>, Stride<_2, _1, _1>>, // 2x2x1 thread group
 | 
					
						
							|  |  |  |       Layout<Shape<_1,_2,_1>>>;                    // 1x2x1 value group for 16x16x8 and LDSM
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A
 | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 4; | 
					
						
							|  |  |  |   using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< | 
					
						
							|  |  |  |     tfloat32_t, LayoutA, kAlignmentA, 32>; | 
					
						
							|  |  |  |   using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; | 
					
						
							|  |  |  |   using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B
 | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 4; | 
					
						
							|  |  |  |   using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< | 
					
						
							|  |  |  |     tfloat32_t, LayoutB, kAlignmentB, 32>; | 
					
						
							|  |  |  |   using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; | 
					
						
							|  |  |  |   using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     tfloat32_t, TagToStrideA_t<LayoutA>, | 
					
						
							|  |  |  |     tfloat32_t, TagToStrideB_t<LayoutB>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<float, 1, float, float>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | template <typename LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     int32_t, LayoutC, | 
					
						
							|  |  |  |     int32_t> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _64>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>, | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,   // 2x2x1 thread group
 | 
					
						
							|  |  |  |       Layout<Shape<_1,_2,_1>>>;  // 1x2x1 value group for 16x16x32 and LDSM
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = decltype( | 
					
						
							|  |  |  |     composition( | 
					
						
							|  |  |  |       Swizzle<2,4,3>{}, | 
					
						
							|  |  |  |       Layout<Shape <_16,_64>, | 
					
						
							|  |  |  |              Stride<_64, _1>>{})); | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 16; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, int8_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_4>, | 
					
						
							|  |  |  |                            Stride< _4,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_1,Int<kAlignmentA>>>{})); | 
					
						
							|  |  |  |   // LDS.32- or LDSM-based copy atom
 | 
					
						
							|  |  |  |   // using SmemCopyAtomA = Copy_Atom<DefaultCopy, uint8_t>;
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<SM75_U32x4_LDSM_N, uint8_t>;  // LDSM works
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = decltype( | 
					
						
							|  |  |  |     composition( | 
					
						
							|  |  |  |       Swizzle<2,4,3>{}, | 
					
						
							|  |  |  |       Layout<Shape <_16,_64>, | 
					
						
							|  |  |  |              Stride<_64, _1>>{})); | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 16; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, int8_t>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_4>, | 
					
						
							|  |  |  |                            Stride< _4,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_1,Int<kAlignmentB>>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // LDS.32- or LDSM-based copy atom
 | 
					
						
							|  |  |  |   // using SmemCopyAtomB = Copy_Atom<DefaultCopy, uint32_t>;
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<SM75_U32x4_LDSM_N, uint8_t>;  // LDSM works
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     int8_t, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     int8_t, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | //////////////////////////// SIMT TWO STAGE ///////////////////////////////////
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | namespace detail { | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element, typename Layout, int ThreadCount, int ShapeM, int ShapeK> | 
					
						
							|  |  |  | struct DefaultGemm_Simt_OperandA; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element> | 
					
						
							|  |  |  | struct DefaultGemm_Simt_OperandA<Element, layout::ColumnMajor, 256, 128, 8> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using SmemLayoutAtom = Layout<Shape <_128,  _8>, | 
					
						
							|  |  |  |                                 Stride<  _1,_128>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<Element>, Element>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_32>>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_1,_1>>{})); | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element> | 
					
						
							|  |  |  | struct DefaultGemm_Simt_OperandA<Element, layout::RowMajor, 256, 128, 8> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using SmemLayoutAtom = Layout<Shape <_128,          _8>, | 
					
						
							|  |  |  |                                 Stride<  _1,Int<128 + 4>>>;   // Padded
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using GmemTiledCopy = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<Element>, Element>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                            Stride< _8, _1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_1,_1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element, typename Layout, int ThreadCount, int ShapeN, int ShapeK> | 
					
						
							|  |  |  | struct DefaultGemm_Simt_OperandB; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element, int ThreadCount, int ShapeN, int ShapeK> | 
					
						
							|  |  |  | struct DefaultGemm_Simt_OperandB<Element, layout::ColumnMajor, ThreadCount, ShapeN, ShapeK> | 
					
						
							|  |  |  |      : DefaultGemm_Simt_OperandA<Element, layout::RowMajor,    ThreadCount, ShapeN, ShapeK> {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | template <typename Element, int ThreadCount, int ShapeN, int ShapeK> | 
					
						
							|  |  |  | struct DefaultGemm_Simt_OperandB<Element, layout::RowMajor,    ThreadCount, ShapeN, ShapeK> | 
					
						
							|  |  |  |      : DefaultGemm_Simt_OperandA<Element, layout::ColumnMajor, ThreadCount, ShapeN, ShapeK> {}; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // end namespace detail
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Two Stage
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ArchTag, | 
					
						
							|  |  |  |   class ElementA, class LayoutA, | 
					
						
							|  |  |  |   class ElementB, class LayoutB, | 
					
						
							|  |  |  |   class ElementC, class LayoutC, | 
					
						
							|  |  |  |   class ElementAccumulator> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, ArchTag, | 
					
						
							|  |  |  |     ElementA, LayoutA, | 
					
						
							|  |  |  |     ElementB, LayoutB, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     ElementAccumulator> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _8>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm70TwoStage; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A
 | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using DefaultOperandA = detail::DefaultGemm_Simt_OperandA<ElementA, LayoutA, ThreadCount, 128, 8>; | 
					
						
							|  |  |  |   using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; | 
					
						
							|  |  |  |   using SmemCopyAtomA   = typename DefaultOperandA::SmemCopyAtom; | 
					
						
							|  |  |  |   using GmemTiledCopyA  = typename DefaultOperandA::GmemTiledCopy; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B
 | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using DefaultOperandB = detail::DefaultGemm_Simt_OperandB<ElementB, LayoutB, ThreadCount, 128, 8>; | 
					
						
							|  |  |  |   using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; | 
					
						
							|  |  |  |   using SmemCopyAtomB   = typename DefaultOperandB::SmemCopyAtom; | 
					
						
							|  |  |  |   using GmemTiledCopyB  = typename DefaultOperandB::GmemTiledCopy; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<LayoutA>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<LayoutB>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // DP4A - int8    Proof-of-concept
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Two Stage TN - idp4a
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ArchTag, | 
					
						
							|  |  |  |   class ElementC, class LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, ArchTag, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     int32_t> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _32>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm70TwoStage; | 
					
						
							|  |  |  |   // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts
 | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM61_DP4A>, | 
					
						
							|  |  |  |       Layout<Shape<_16,_16,_1>>>;  // Tile of atoms (threads)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  K-major
 | 
					
						
							|  |  |  |   using ElementA = int8_t; | 
					
						
							|  |  |  |   // 40% from regular M and N major layout
 | 
					
						
							|  |  |  |   // using SmemLayoutAtomA = Layout<Shape <_128,_32>,
 | 
					
						
							|  |  |  |   //                                Stride<  _1,_128>>;
 | 
					
						
							|  |  |  |   // 80% from interleaved layouts
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 4; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_4>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  K-major
 | 
					
						
							|  |  |  |   using ElementB = int8_t; | 
					
						
							|  |  |  |   // 40% from regular M and N major layout
 | 
					
						
							|  |  |  |   // using SmemLayoutAtomB = Layout<Shape <_128,_32>,
 | 
					
						
							|  |  |  |   //                                Stride<  _1,_128>>;
 | 
					
						
							|  |  |  |   // 80% from interleaved layouts
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 4; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_4>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Two Stage NN - idp4a
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ArchTag, | 
					
						
							|  |  |  |   class ElementC, class LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, ArchTag, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     int32_t> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _32>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm70TwoStage; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM61_DP4A>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  M-major
 | 
					
						
							|  |  |  |   using ElementA = int8_t; | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_32>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  K-major
 | 
					
						
							|  |  |  |   using ElementB = int8_t; | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 4; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_4>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-03-10 12:22:56 +08:00
										 |  |  |   // Epilogue
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Two Stage NT - idp4a
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ArchTag, | 
					
						
							|  |  |  |   class ElementC, class LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, ArchTag, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     int32_t> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _32>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm70TwoStage; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM61_DP4A>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  M-major
 | 
					
						
							|  |  |  |   using ElementA = int8_t; | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_32>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  N-major
 | 
					
						
							|  |  |  |   using ElementB = int8_t; | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_32>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Two Stage TT - idp4a
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ArchTag, | 
					
						
							|  |  |  |   class ElementC, class LayoutC> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, ArchTag, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     int8_t, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     int32_t> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _32>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm70TwoStage; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM61_DP4A>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  K-major
 | 
					
						
							|  |  |  |   using ElementA = int8_t; | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 4; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_4>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  N-major
 | 
					
						
							|  |  |  |   using ElementB = int8_t; | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4,  _8>>, | 
					
						
							|  |  |  |                                  Stride<  _4, Stride<_1,_512>>>; | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_32, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_32>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | /////////////////////////// SIMT MULTI STAGE //////////////////////////////////
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Multi Stage NT
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ElementA, | 
					
						
							|  |  |  |   class ElementB, | 
					
						
							|  |  |  |   class ElementC, class LayoutC, | 
					
						
							|  |  |  |   class ElementAccumulator> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, arch::Sm80, | 
					
						
							|  |  |  |     ElementA, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementB, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     ElementAccumulator> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>, | 
					
						
							|  |  |  |       Layout<Shape< _2,  _2, _1>>, | 
					
						
							|  |  |  |       Tile<Layout<_2,_16>,Layout<_2,_16>,Underscore>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  M-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape<_128,_16>>; | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 2; | 
					
						
							|  |  |  |   using AlignmentTypeA = cute::uint_byte_t<static_cast<int>(sizeof(ElementA)) * kAlignmentA>; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeA>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_32,_8>>{}, | 
					
						
							|  |  |  |                     Layout<Shape< _2,_1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  N-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape<_128,_16>>; | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 2; | 
					
						
							|  |  |  |   using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * kAlignmentB>; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeB>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_32,_8>>{}, | 
					
						
							|  |  |  |                     Layout<Shape< _2,_1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Multi Stage TN
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ElementA, | 
					
						
							|  |  |  |   class ElementB, | 
					
						
							|  |  |  |   class ElementC, class LayoutC, | 
					
						
							|  |  |  |   class ElementAccumulator> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, arch::Sm80, | 
					
						
							|  |  |  |     ElementA, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementB, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     ElementAccumulator> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape <_128,          _16>, | 
					
						
							|  |  |  |                                  Stride<  _1, Int<128 + 1>>>;  // Padded by kAlignmentA
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementA>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128,          _16>, | 
					
						
							|  |  |  |                                  Stride<  _1, Int<128 + 1>>>;  // Padded by kAlignmentB
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementB>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Multi Stage NN
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ElementA, | 
					
						
							|  |  |  |   class ElementB, | 
					
						
							|  |  |  |   class ElementC, class LayoutC, | 
					
						
							|  |  |  |   class ElementAccumulator> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, arch::Sm80, | 
					
						
							|  |  |  |     ElementA, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementB, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     ElementAccumulator> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>, | 
					
						
							|  |  |  |       Layout<Shape< _2,  _1, _1>>, | 
					
						
							|  |  |  |       Tile<Layout<_2,_16>,Underscore,Underscore>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  M-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape<_128,_16>>; | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 2; | 
					
						
							|  |  |  |   using AlignmentTypeA = cute::uint_byte_t<static_cast<int>(sizeof(ElementA)) * kAlignmentA>; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeA>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_32,_8>>{}, | 
					
						
							|  |  |  |                     Layout<Shape< _2,_1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128,          _16>, | 
					
						
							|  |  |  |                                  Stride<  _1, Int<128 + 1>>>;  // Padded by kAlignmentB
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementB>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // SIMT Multi Stage TT
 | 
					
						
							|  |  |  | template < | 
					
						
							|  |  |  |   class ElementA, | 
					
						
							|  |  |  |   class ElementB, | 
					
						
							|  |  |  |   class ElementC, class LayoutC, | 
					
						
							|  |  |  |   class ElementAccumulator> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassSimt, arch::Sm80, | 
					
						
							|  |  |  |     ElementA, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementB, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     ElementC, LayoutC, | 
					
						
							|  |  |  |     ElementAccumulator> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _128, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 256; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>, | 
					
						
							|  |  |  |       Layout<Shape<_16, _16, _1>>, | 
					
						
							|  |  |  |       Layout<Shape< _1,  _2, _1>>, | 
					
						
							|  |  |  |       Tile<Underscore,Layout<_2,_16>,Underscore>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = Layout<Shape <_128,          _16>, | 
					
						
							|  |  |  |                                  Stride<  _1, Int<128 + 1>>>;  // Padded by kAlignmentA
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementA>, ElementA>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  N-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = Layout<Shape <_128,_16>>; | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 2; | 
					
						
							|  |  |  |   using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * kAlignmentB>; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeB>, ElementB>{}, | 
					
						
							|  |  |  |                     Layout<Shape<_32,_8>>{}, | 
					
						
							|  |  |  |                     Layout<Shape< _2,_1>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     ElementA, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     ElementB, TagToStrideB_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							|  |  |  |     TagToStrideC_t<LayoutC>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Ampere fp64 MMA TN (K-Major A and K-Major B)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     double, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _64, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,            // Atom
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Atom layout
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Val layout
 | 
					
						
							|  |  |  |       Tile<Layout<_2,_16>,Layout<_2,_16>,Underscore>>; // Mode permutations
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A  (M,K)  K-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,0,4>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_4,_16>, | 
					
						
							|  |  |  |                          Stride<_1, _4>>{})); // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape < _8,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_1,_1>>{}));                            // Value layout: 1x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B  (N,K)  K-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,0,4>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_4,_16>, | 
					
						
							|  |  |  |                          Stride<_1, _4>>{})); // N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape < _8,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_1,_1>>{}));                            // Value layout: 1x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     double, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     double, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<double, 1, double, double>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | /*
 | 
					
						
							|  |  |  |   using EpilogueOutputOp = epilogue::collective::Epilogue< | 
					
						
							|  |  |  |       epilogue::thread::LinearCombination<double, 1, double, double>, | 
					
						
							|  |  |  |       Layout<Shape <_64,_32>, | 
					
						
							|  |  |  |              Stride< _1,_64>>,                                           // SMEM layout
 | 
					
						
							|  |  |  |       Copy_Atom<UniversalCopy<double>,double>,                           // R2S with tiled_mma layout
 | 
					
						
							|  |  |  |       decltype(make_tiled_copy(Copy_Atom<UniversalCopy<double>,double>{},// S2R
 | 
					
						
							|  |  |  |                                Layout<Shape <_16,_16>, | 
					
						
							|  |  |  |                                       Stride< _1,_16>>{},                // Thread layout
 | 
					
						
							|  |  |  |                                Layout<Shape<_2,_1>>{})),                 // Value layout
 | 
					
						
							|  |  |  |       Copy_Atom<UniversalCopy<double>,double>                            // R2G with S2R_dst layout
 | 
					
						
							|  |  |  |       >; | 
					
						
							|  |  |  | */ | 
					
						
							|  |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Ampere fp64 MMA NN (M-Major A and K-Major B)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _64, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,            // Atom
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Atom layout
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Val layout
 | 
					
						
							|  |  |  |       Tile<Layout<_2,_16>,Layout<_2,_16>,Underscore>>; // Mode permutations
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A  (M,K)  M-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,2,2>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_16, _4>, | 
					
						
							|  |  |  |                          Stride< _1,_16>>{})); // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 2; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape <_16, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_16>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_2,_1>>{}));                            // Value layout: 2x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B  (N,K)  K-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,0,4>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_4,_16>, | 
					
						
							|  |  |  |                          Stride<_1, _4>>{}));// N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape < _8,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_1,_1>>{}));                            // Value layout: 1x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     double, TagToStrideA_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     double, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<double, 1, double, double>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Ampere fp64 MMA NT (M-Major A and N-Major B)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _64, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,            // Atom
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Atom layout
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Val layout
 | 
					
						
							|  |  |  |       Tile<Layout<_2,_16>,Layout<_2,_16>,Underscore>>; // Mode permutations
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A  (M,K)  M-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,2,2>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_16, _4>, | 
					
						
							|  |  |  |                          Stride< _1,_16>>{})); // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 2; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape <_16, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_16>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_2,_1>>{}));                            // Value layout: 2x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B  (N,K)  N-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,2,2>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_16, _4>, | 
					
						
							|  |  |  |                          Stride< _1,_16>>{})); // N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 2; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape <_16, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_16>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_2,_1>>{}));                            // Value layout: 2x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     double, TagToStrideA_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     double, TagToStrideB_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<double, 1, double, double>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Ampere fp64 MMA TT (K-Major A and N-Major B)
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm80, | 
					
						
							|  |  |  |     double, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _64, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>,            // Atom
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Atom layout
 | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>,                         // Val layout
 | 
					
						
							|  |  |  |       Tile<Layout<_2,_16>,Layout<_2,_16>,Underscore>>; // Mode permutations
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A  (M,K)  K-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,0,4>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_4,_16>, | 
					
						
							|  |  |  |                          Stride<_1, _4>>{})); // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 1; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape < _8,_16>, | 
					
						
							|  |  |  |                            Stride<_16, _1>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_1,_1>>{}));                            // Value layout: 1x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B  (N,K)  N-Major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = decltype( | 
					
						
							| 
									
										
										
										
											2023-08-08 08:50:32 +08:00
										 |  |  |       composition(Swizzle<2,2,2>{}, | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  |                   Layout<Shape <_16, _4>, | 
					
						
							|  |  |  |                          Stride< _1,_16>>{})); // N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 2; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape <_16, _8>, | 
					
						
							|  |  |  |                            Stride< _1,_16>>{},                           // ThrLayout for CopyAtom
 | 
					
						
							|  |  |  |                     Layout<Shape<_2,_1>>{}));                            // Value layout: 2x1 doubles
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     double, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     double, TagToStrideB_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							|  |  |  |   using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TagToStrideC_t<cutlass::layout::ColumnMajor>, | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |     epilogue::thread::LinearCombination<double, 1, double, double>, | 
					
						
							|  |  |  |     cutlass::gemm::EpilogueDefault>; | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | // Hopper fp64 MMA TN
 | 
					
						
							|  |  |  | template <> | 
					
						
							|  |  |  | struct DefaultGemmConfigurationToCutlass3Types< | 
					
						
							|  |  |  |     arch::OpClassTensorOp, arch::Sm90, | 
					
						
							|  |  |  |     double, cutlass::layout::RowMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, | 
					
						
							|  |  |  |     double> | 
					
						
							|  |  |  | { | 
					
						
							|  |  |  |   using TileShape = Shape<_128, _64, _16>; | 
					
						
							|  |  |  |   static constexpr int ThreadCount = 128; | 
					
						
							|  |  |  |   using DispatchPolicy = MainloopSm80CpAsync<3>; | 
					
						
							|  |  |  |   using TiledMma = TiledMMA< | 
					
						
							|  |  |  |       MMA_Atom<SM90_16x8x16_F64F64F64F64_TN>, | 
					
						
							|  |  |  |       Layout<Shape<_2,_2,_1>>>; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // A (M,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomA = decltype( | 
					
						
							|  |  |  |     make_ordered_layout(Shape<_128,_16>{}, | 
					
						
							|  |  |  |                         Step <  _2, _1>{})); // M, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentA = 2; | 
					
						
							|  |  |  |   using GmemTiledCopyA = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_2>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // B (N,K)  K-major
 | 
					
						
							|  |  |  |   using SmemLayoutAtomB = decltype( | 
					
						
							|  |  |  |     make_ordered_layout(Shape<_64,_16>{}, | 
					
						
							|  |  |  |                         Step < _2, _1>{}));                       // N, K
 | 
					
						
							|  |  |  |   using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>; | 
					
						
							|  |  |  |   static constexpr int kAlignmentB = 2; | 
					
						
							|  |  |  |   using GmemTiledCopyB = decltype( | 
					
						
							|  |  |  |     make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, | 
					
						
							|  |  |  |                     Layout<Shape <_16,_8>, | 
					
						
							|  |  |  |                            Stride< _8,_1>>{}, | 
					
						
							|  |  |  |                     Layout<Shape < _1,_2>>{})); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Mainloop
 | 
					
						
							|  |  |  |   using CollectiveMainloop = collective::CollectiveMma< | 
					
						
							|  |  |  |     DispatchPolicy, TileShape, | 
					
						
							|  |  |  |     double, TagToStrideA_t<cutlass::layout::RowMajor>, | 
					
						
							|  |  |  |     double, TagToStrideB_t<cutlass::layout::ColumnMajor>, | 
					
						
							|  |  |  |     TiledMma, | 
					
						
							|  |  |  |     GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,  // A
 | 
					
						
							|  |  |  |     GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity   // B
 | 
					
						
							|  |  |  |   >; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |   // Epilogue
 | 
					
						
							| 
									
										
										
										
											2023-04-15 11:19:34 +08:00
										 |  |  |   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< | 
					
						
							|  |  |  |     cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, | 
					
						
							|  |  |  |     TileShape, Shape<_1,_1,_1>, | 
					
						
							|  |  |  |     cutlass::epilogue::collective::EpilogueTileAuto, | 
					
						
							|  |  |  |     double, double, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, 1, | 
					
						
							|  |  |  |     double, cutlass::layout::ColumnMajor, 1, | 
					
						
							|  |  |  |     cutlass::epilogue::collective::EpilogueScheduleAuto | 
					
						
							|  |  |  |   >::CollectiveOp; | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-01-24 09:55:28 +08:00
										 |  |  | }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | ///////////////////////////////////////////////////////////////////////////////
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | } // namespace device
 | 
					
						
							|  |  |  | } // namespace gemm
 | 
					
						
							|  |  |  | } // namespace cutlass
 |