
* CUTLASS 3.4.0 * Update CHANGELOG.md --------- Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
1369 lines
52 KiB
C++
1369 lines
52 KiB
C++
/***************************************************************************************************
|
|
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
* SPDX-License-Identifier: BSD-3-Clause
|
|
*
|
|
* Redistribution and use in source and binary forms, with or without
|
|
* modification, are permitted provided that the following conditions are met:
|
|
*
|
|
* 1. Redistributions of source code must retain the above copyright notice, this
|
|
* list of conditions and the following disclaimer.
|
|
*
|
|
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
* this list of conditions and the following disclaimer in the documentation
|
|
* and/or other materials provided with the distribution.
|
|
*
|
|
* 3. Neither the name of the copyright holder nor the names of its
|
|
* contributors may be used to endorse or promote products derived from
|
|
* this software without specific prior written permission.
|
|
*
|
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
*
|
|
**************************************************************************************************/
|
|
#pragma once
|
|
|
|
#include "cute/atom/mma_atom.hpp"
|
|
#include "cute/atom/copy_atom.hpp"
|
|
|
|
#include "cutlass/cutlass.h"
|
|
#include "cutlass/gemm/gemm.h"
|
|
#include "cutlass/arch/arch.h"
|
|
#include "cutlass/arch/mma.h"
|
|
#include "cutlass/layout/layout.h"
|
|
#include "cutlass/gemm/dispatch_policy.hpp"
|
|
#include "cutlass/gemm/collective/collective_mma.hpp"
|
|
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
|
|
|
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
|
#include "cutlass/epilogue/thread/linear_combination.h"
|
|
|
|
namespace cutlass {
|
|
namespace gemm {
|
|
namespace device {
|
|
using namespace cute;
|
|
|
|
/// Primary template of the 2.x -> 3.x configuration map.
///
/// This type is only intended to demonstrate porting 2.x kernels to 3.0. It is
/// keyed on an operator class, an architecture tag, the A/B/C element types with
/// their layout tags, and the accumulator element type. Supported combinations
/// are provided as (partial) specializations; the primary template itself
/// carries no configuration and rejects instantiation.
template <
  class OperatorClass, class ArchTag,
  class ElementA, class LayoutA,
  class ElementB, class LayoutB,
  class ElementC, class LayoutC,
  class ElementAccumulator>
struct DefaultGemmConfigurationToCutlass3Types {
  // The condition has to depend on a template parameter so that it is only
  // evaluated when this primary template is instantiated. sizeof(ElementA*) is
  // used instead of sizeof(ElementA) because it stays well-formed when ElementA
  // is void or an incomplete type, so the message below is what the user sees.
  static_assert(sizeof(ElementA*) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists.");
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace detail {
|
|
|
|
/// Per-operand defaults for operand A of the SM80 tensor-op configurations,
/// keyed on element type, layout tag, alignment and K extent.
/// Declaration only: each supported combination is a specialization.
template <class Element, class Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandA;

/// Per-operand defaults for operand B, keyed the same way as operand A.
/// Declaration only: each supported combination is a specialization.
template <class Element, class Layout, int Alignment, int SizeK>
struct DefaultGemm_TensorOpSm80_OperandB;
|
|
|
|
//
|
|
// F16: 128-by-128-by-64
|
|
//
|
|
|
|
/// Operand A - Row-major (K-Major)
|
|
template <>
|
|
struct DefaultGemm_TensorOpSm80_OperandA<half_t, layout::RowMajor, 8, 64>
|
|
{
|
|
// Smem
|
|
using SmemLayoutAtom = decltype(
|
|
composition(Swizzle<3,3,3>{},
|
|
Layout<Shape < _8,_64>,
|
|
Stride<_64, _1>>{}));
|
|
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;
|
|
|
|
// Gmem
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, half_t>{},
|
|
Layout<Shape <_16,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_8>>{}));
|
|
};
|
|
|
|
/// Operand A - Column-major (M-major)
|
|
template <int SizeK>
|
|
struct DefaultGemm_TensorOpSm80_OperandA<half_t, layout::ColumnMajor, 8, SizeK>
|
|
{
|
|
// Smem
|
|
using SmemLayoutAtom = decltype(
|
|
composition(Swizzle<3,3,3>{},
|
|
Layout<Shape <_64, _8>,
|
|
Stride< _1,_64>>{}));
|
|
using SmemCopyAtom = Copy_Atom<SM75_U16x8_LDSM_T, half_t>;
|
|
|
|
// Gmem
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, half_t>{},
|
|
Layout<Shape <_16, _8>,
|
|
Stride< _1,_16>>{},
|
|
Layout<Shape < _8, _1>>{}));
|
|
};
|
|
|
|
// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
|
|
|
|
// Operand B - Column-Major (K-major)
|
|
template <int Alignment, int SizeK>
|
|
struct DefaultGemm_TensorOpSm80_OperandB<half_t, layout::ColumnMajor, Alignment, SizeK>
|
|
: DefaultGemm_TensorOpSm80_OperandA<half_t, layout::RowMajor, Alignment, SizeK>
|
|
{};
|
|
|
|
// Operand B - Row-Major (N-major)
|
|
template <int Alignment, int SizeK>
|
|
struct DefaultGemm_TensorOpSm80_OperandB<half_t, layout::RowMajor, Alignment, SizeK>
|
|
: DefaultGemm_TensorOpSm80_OperandA<half_t, layout::ColumnMajor, Alignment, SizeK>
|
|
{};
|
|
|
|
//
|
|
// F16: 128-by-128-by-32 (small k-block)
|
|
//
|
|
|
|
/// Operand A - Row-major (K-Major)
|
|
template <>
|
|
struct DefaultGemm_TensorOpSm80_OperandA<half_t, layout::RowMajor, 8, 32>
|
|
{
|
|
// Smem
|
|
using SmemLayoutAtom = decltype(
|
|
composition(Swizzle<2,3,3>{},
|
|
Layout<Shape < _8,_32>,
|
|
Stride<_32, _1>>{}));
|
|
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, half_t>;
|
|
|
|
// Gmem
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, half_t>{},
|
|
Layout<Shape <_32,_4>,
|
|
Stride< _4,_1>>{},
|
|
Layout<Shape < _1,_8>>{}));
|
|
};
|
|
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Ampere MMA F32F16
|
|
template <typename LayoutA, typename LayoutB, typename LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
half_t, LayoutA,
|
|
half_t, LayoutB,
|
|
float, LayoutC,
|
|
float>
|
|
{
|
|
using TileShape = Shape<_128, _128, _32>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
|
Layout<Shape<_2,_2,_1>>, // 2x2x1 thread group
|
|
Tile<_32,_32,_16>>; // 32x32x16 MMA for LDSM, 1x2x1 value group
|
|
|
|
// A
|
|
static constexpr int kAlignmentA = 8;
|
|
using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA<
|
|
half_t, LayoutA, kAlignmentA, 32>;
|
|
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
|
|
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
|
|
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
|
|
|
|
// B
|
|
static constexpr int kAlignmentB = 8;
|
|
using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB<
|
|
half_t, LayoutB, kAlignmentB, 32>;
|
|
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
|
|
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
|
|
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
half_t, TagToStrideA_t<LayoutA>,
|
|
half_t, TagToStrideB_t<LayoutB>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<float, 1, float, float>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace detail {
|
|
|
|
//
|
|
// TF32: 128-by-128-by-kblock (kBlock = 16, 32)
|
|
//
|
|
|
|
/// Operand A - Row-major (K-major) (kBlock = 32)
|
|
template <>
|
|
struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::RowMajor, 4, 32>
|
|
{
|
|
// Smem
|
|
using SmemLayoutAtom = decltype(
|
|
composition(Swizzle<3,2,3>{},
|
|
Layout<Shape < _8,_32>,
|
|
Stride<_32, _1>>{}));
|
|
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, tfloat32_t>;
|
|
|
|
// Gmem
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, tfloat32_t>{},
|
|
Layout<Shape <_16,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_4>>{}));
|
|
};
|
|
|
|
/// Operand A - Row-major (K-major) (kBlock = 16)
|
|
template <>
|
|
struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::RowMajor, 4, 16>
|
|
{
|
|
// Smem
|
|
using SmemLayoutAtom = decltype(
|
|
composition(Swizzle<2,2,3>{},
|
|
Layout<Shape < _8,_16>,
|
|
Stride<_16, _1>>{}));
|
|
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, tfloat32_t>;
|
|
// Gmem
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, tfloat32_t>{},
|
|
Layout<Shape <_32,_4>,
|
|
Stride< _4,_1>>{},
|
|
Layout<Shape < _1,_4>>{}));
|
|
};
|
|
|
|
/// Operand A - Column-major (M-major)
|
|
template <int SizeK>
|
|
struct DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::ColumnMajor, 4, SizeK>
|
|
{
|
|
// Smem
|
|
using SmemLayoutAtom = decltype(
|
|
composition(Swizzle<3,2,3>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _1,_32>>{}));
|
|
using SmemCopyAtom = Copy_Atom<UniversalCopy<tfloat32_t>, tfloat32_t>;
|
|
// Gmem
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, tfloat32_t>{},
|
|
Layout<Shape <_16, _8>,
|
|
Stride< _1,_16>>{},
|
|
Layout<Shape < _4, _1>>{}));
|
|
};
|
|
|
|
// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands
|
|
|
|
// Operand B - Column-Major (K-major)
|
|
template <int Alignment, int SizeK>
|
|
struct DefaultGemm_TensorOpSm80_OperandB<tfloat32_t, layout::ColumnMajor, Alignment, SizeK>
|
|
: DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::RowMajor, Alignment, SizeK>
|
|
{};
|
|
|
|
// Operand B - Row-Major (N-major)
|
|
template <int Alignment, int SizeK>
|
|
struct DefaultGemm_TensorOpSm80_OperandB<tfloat32_t, layout::RowMajor, Alignment, SizeK>
|
|
: DefaultGemm_TensorOpSm80_OperandA<tfloat32_t, layout::ColumnMajor, Alignment, SizeK>
|
|
{};
|
|
|
|
}
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Ampere MMA F32TF32
|
|
template <typename LayoutA, typename LayoutB, typename LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
tfloat32_t, LayoutA,
|
|
tfloat32_t, LayoutB,
|
|
float, LayoutC,
|
|
float>
|
|
{
|
|
using TileShape = Shape<_128, _128, _32>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
|
|
Layout<Shape<_2,_2,_1>, Stride<_2, _1, _1>>, // 2x2x1 thread group
|
|
Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group
|
|
|
|
// A
|
|
static constexpr int kAlignmentA = 4;
|
|
using DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA<
|
|
tfloat32_t, LayoutA, kAlignmentA, 32>;
|
|
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K
|
|
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
|
|
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
|
|
|
|
// B
|
|
static constexpr int kAlignmentB = 4;
|
|
using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB<
|
|
tfloat32_t, LayoutB, kAlignmentB, 32>;
|
|
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K
|
|
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
|
|
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
tfloat32_t, TagToStrideA_t<LayoutA>,
|
|
tfloat32_t, TagToStrideB_t<LayoutB>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<float, 1, float, float>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
template <typename LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
int8_t, cutlass::layout::RowMajor,
|
|
int8_t, cutlass::layout::ColumnMajor,
|
|
int32_t, LayoutC,
|
|
int32_t>
|
|
{
|
|
using TileShape = Shape<_128, _128, _64>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_16x8x32_S32S8S8S32_TN>,
|
|
Layout<Shape<_2,_2,_1>>, // 2x2x1 thread group
|
|
Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group
|
|
|
|
// A (M,K) K-major
|
|
using SmemLayoutAtomA = decltype(
|
|
composition(
|
|
Swizzle<2,4,3>{},
|
|
Layout<Shape <_16,_64>,
|
|
Stride<_64, _1>>{}));
|
|
static constexpr int kAlignmentA = 16;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, int8_t>{},
|
|
Layout<Shape <_32,_4>,
|
|
Stride< _4,_1>>{},
|
|
Layout<Shape<_1,Int<kAlignmentA>>>{}));
|
|
// LDS.32- or LDSM-based copy atom
|
|
// using SmemCopyAtomA = Copy_Atom<DefaultCopy, uint8_t>;
|
|
using SmemCopyAtomA = Copy_Atom<SM75_U32x4_LDSM_N, uint8_t>; // LDSM works
|
|
|
|
// B (N,K) K-major
|
|
using SmemLayoutAtomB = decltype(
|
|
composition(
|
|
Swizzle<2,4,3>{},
|
|
Layout<Shape <_16,_64>,
|
|
Stride<_64, _1>>{}));
|
|
static constexpr int kAlignmentB = 16;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, int8_t>{},
|
|
Layout<Shape <_32,_4>,
|
|
Stride< _4,_1>>{},
|
|
Layout<Shape<_1,Int<kAlignmentB>>>{}));
|
|
|
|
// LDS.32- or LDSM-based copy atom
|
|
// using SmemCopyAtomB = Copy_Atom<DefaultCopy, uint32_t>;
|
|
using SmemCopyAtomB = Copy_Atom<SM75_U32x4_LDSM_N, uint8_t>; // LDSM works
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
int8_t, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
int8_t, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<int32_t, 1, int32_t, int32_t>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
//////////////////////////// SIMT TWO STAGE ///////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
namespace detail {
|
|
|
|
/// Per-operand defaults for operand A of the SIMT configurations, keyed on
/// element type, layout tag, thread count and the M and K extents.
/// Declaration only: each supported combination is a specialization.
template <class Element, class Layout, int ThreadCount, int ShapeM, int ShapeK>
struct DefaultGemm_Simt_OperandA;
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
template <typename Element>
|
|
struct DefaultGemm_Simt_OperandA<Element, layout::ColumnMajor, 256, 128, 8>
|
|
{
|
|
using SmemLayoutAtom = Layout<Shape <_128, _8>,
|
|
Stride< _1,_128>>;
|
|
|
|
using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
|
|
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<Element>, Element>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _1,_32>>{},
|
|
Layout<Shape<_1,_1>>{}));
|
|
};
|
|
|
|
template <typename Element>
|
|
struct DefaultGemm_Simt_OperandA<Element, layout::RowMajor, 256, 128, 8>
|
|
{
|
|
using SmemLayoutAtom = Layout<Shape <_128, _8>,
|
|
Stride< _1,Int<128 + 4>>>; // Padded
|
|
|
|
using SmemCopyAtom = Copy_Atom<DefaultCopy, Element>;
|
|
|
|
using GmemTiledCopy = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<Element>, Element>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _8, _1>>{},
|
|
Layout<Shape<_1,_1>>{}));
|
|
|
|
};
|
|
|
|
/// Per-operand defaults for operand B of the SIMT configurations, keyed on
/// element type, layout tag, thread count and the N and K extents.
/// Declaration only: each supported combination is a specialization.
template <class Element, class Layout, int ThreadCount, int ShapeN, int ShapeK>
struct DefaultGemm_Simt_OperandB;
|
|
|
|
template <typename Element, int ThreadCount, int ShapeN, int ShapeK>
|
|
struct DefaultGemm_Simt_OperandB<Element, layout::ColumnMajor, ThreadCount, ShapeN, ShapeK>
|
|
: DefaultGemm_Simt_OperandA<Element, layout::RowMajor, ThreadCount, ShapeN, ShapeK> {};
|
|
|
|
template <typename Element, int ThreadCount, int ShapeN, int ShapeK>
|
|
struct DefaultGemm_Simt_OperandB<Element, layout::RowMajor, ThreadCount, ShapeN, ShapeK>
|
|
: DefaultGemm_Simt_OperandA<Element, layout::ColumnMajor, ThreadCount, ShapeN, ShapeK> {};
|
|
|
|
} // end namespace detail
|
|
|
|
// SIMT Two Stage
|
|
template <
|
|
class ArchTag,
|
|
class ElementA, class LayoutA,
|
|
class ElementB, class LayoutB,
|
|
class ElementC, class LayoutC,
|
|
class ElementAccumulator>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, ArchTag,
|
|
ElementA, LayoutA,
|
|
ElementB, LayoutB,
|
|
ElementC, LayoutC,
|
|
ElementAccumulator>
|
|
{
|
|
using TileShape = Shape<_128, _128, _8>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm70TwoStage;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>,
|
|
Layout<Shape<_16, _16, _1>>>;
|
|
|
|
// A
|
|
static constexpr int kAlignmentA = 1;
|
|
using DefaultOperandA = detail::DefaultGemm_Simt_OperandA<ElementA, LayoutA, ThreadCount, 128, 8>;
|
|
using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom;
|
|
using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom;
|
|
using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy;
|
|
|
|
// B
|
|
static constexpr int kAlignmentB = 1;
|
|
using DefaultOperandB = detail::DefaultGemm_Simt_OperandB<ElementB, LayoutB, ThreadCount, 128, 8>;
|
|
using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom;
|
|
using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom;
|
|
using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy;
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<LayoutA>,
|
|
ElementB, TagToStrideB_t<LayoutB>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
|
|
//
|
|
// DP4A - int8 Proof-of-concept
|
|
//
|
|
|
|
// SIMT Two Stage TN - idp4a
|
|
template <
|
|
class ArchTag,
|
|
class ElementC, class LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, ArchTag,
|
|
int8_t, cutlass::layout::RowMajor,
|
|
int8_t, cutlass::layout::ColumnMajor,
|
|
ElementC, LayoutC,
|
|
int32_t>
|
|
{
|
|
using TileShape = Shape<_128, _128, _32>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm70TwoStage;
|
|
// NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM61_DP4A>,
|
|
Layout<Shape<_16,_16,_1>>>; // Tile of atoms (threads)
|
|
|
|
// A (M,K) K-major
|
|
using ElementA = int8_t;
|
|
// 40% from regular M and N major layout
|
|
// using SmemLayoutAtomA = Layout<Shape <_128,_32>,
|
|
// Stride< _1,_128>>;
|
|
// 80% from interleaved layouts
|
|
using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 4;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementA>{},
|
|
Layout<Shape <_32,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_4>>{}));
|
|
|
|
// B (N,K) K-major
|
|
using ElementB = int8_t;
|
|
// 40% from regular M and N major layout
|
|
// using SmemLayoutAtomB = Layout<Shape <_128,_32>,
|
|
// Stride< _1,_128>>;
|
|
// 80% from interleaved layouts
|
|
using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 4;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementB>{},
|
|
Layout<Shape <_32,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_4>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Two Stage NN - idp4a
|
|
template <
|
|
class ArchTag,
|
|
class ElementC, class LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, ArchTag,
|
|
int8_t, cutlass::layout::ColumnMajor,
|
|
int8_t, cutlass::layout::ColumnMajor,
|
|
ElementC, LayoutC,
|
|
int32_t>
|
|
{
|
|
using TileShape = Shape<_128, _128, _32>;
|
|
static constexpr int ThreadCount = 256;
|
|
|
|
using DispatchPolicy = MainloopSm70TwoStage;
|
|
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM61_DP4A>,
|
|
Layout<Shape<_16, _16, _1>>>;
|
|
|
|
// A (M,K) M-major
|
|
using ElementA = int8_t;
|
|
using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 1;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementA>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _1,_32>>{},
|
|
Layout<Shape < _1, _1>>{}));
|
|
|
|
// B (N,K) K-major
|
|
using ElementB = int8_t;
|
|
using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 4;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementB>{},
|
|
Layout<Shape <_32,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_4>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Two Stage NT - idp4a
|
|
template <
|
|
class ArchTag,
|
|
class ElementC, class LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, ArchTag,
|
|
int8_t, cutlass::layout::ColumnMajor,
|
|
int8_t, cutlass::layout::RowMajor,
|
|
ElementC, LayoutC,
|
|
int32_t>
|
|
{
|
|
using TileShape = Shape<_128, _128, _32>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm70TwoStage;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM61_DP4A>,
|
|
Layout<Shape<_16, _16, _1>>>;
|
|
|
|
// A (M,K) M-major
|
|
using ElementA = int8_t;
|
|
using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 1;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementA>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _1,_32>>{},
|
|
Layout<Shape < _1, _1>>{}));
|
|
|
|
// B (N,K) N-major
|
|
using ElementB = int8_t;
|
|
using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 1;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementB>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _1,_32>>{},
|
|
Layout<Shape < _1, _1>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::RowMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Two Stage TT - idp4a
|
|
template <
|
|
class ArchTag,
|
|
class ElementC, class LayoutC>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, ArchTag,
|
|
int8_t, cutlass::layout::RowMajor,
|
|
int8_t, cutlass::layout::RowMajor,
|
|
ElementC, LayoutC,
|
|
int32_t>
|
|
{
|
|
using TileShape = Shape<_128, _128, _32>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm70TwoStage;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM61_DP4A>,
|
|
Layout<Shape<_16, _16, _1>>>;
|
|
|
|
// A (M,K) K-major
|
|
using ElementA = int8_t;
|
|
using SmemLayoutAtomA = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 4;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint32_t>, ElementA>{},
|
|
Layout<Shape <_32,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_4>>{}));
|
|
|
|
// B (N,K) N-major
|
|
using ElementB = int8_t;
|
|
using SmemLayoutAtomB = Layout<Shape <_128, Shape <_4, _8>>,
|
|
Stride< _4, Stride<_1,_512>>>;
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 1;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<UniversalCopy<cute::uint8_t>, ElementB>{},
|
|
Layout<Shape <_32, _8>,
|
|
Stride< _1,_32>>{},
|
|
Layout<Shape < _1, _1>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::RowMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, int32_t, int32_t>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
/////////////////////////// SIMT MULTI STAGE //////////////////////////////////
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Multi Stage NT
|
|
template <
|
|
class ElementA,
|
|
class ElementB,
|
|
class ElementC, class LayoutC,
|
|
class ElementAccumulator>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, arch::Sm80,
|
|
ElementA, cutlass::layout::ColumnMajor,
|
|
ElementB, cutlass::layout::RowMajor,
|
|
ElementC, LayoutC,
|
|
ElementAccumulator>
|
|
{
|
|
using TileShape = Shape<_128, _128, _16>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>,
|
|
Layout<Shape<_16, _16, _1>>, // 16x16x1 thread group
|
|
Tile<Layout<Shape<_16,_2>,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization
|
|
Layout<Shape<_16,_2>,Stride<_2,_1>>,Underscore>>;
|
|
|
|
// A (M,K) M-major
|
|
using SmemLayoutAtomA = Layout<Shape<_128,_16>>;
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 2;
|
|
using AlignmentTypeA = cute::uint_byte_t<static_cast<int>(sizeof(ElementA)) * kAlignmentA>;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeA>, ElementA>{},
|
|
Layout<Shape<_32,_8>>{},
|
|
Layout<Shape< _2,_1>>{}));
|
|
|
|
// B (N,K) N-major
|
|
using SmemLayoutAtomB = Layout<Shape<_128,_16>>;
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 2;
|
|
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * kAlignmentB>;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeB>, ElementB>{},
|
|
Layout<Shape<_32,_8>>{},
|
|
Layout<Shape< _2,_1>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::RowMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Multi Stage TN
|
|
template <
|
|
class ElementA,
|
|
class ElementB,
|
|
class ElementC, class LayoutC,
|
|
class ElementAccumulator>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, arch::Sm80,
|
|
ElementA, cutlass::layout::RowMajor,
|
|
ElementB, cutlass::layout::ColumnMajor,
|
|
ElementC, LayoutC,
|
|
ElementAccumulator>
|
|
{
|
|
using TileShape = Shape<_128, _128, _16>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>,
|
|
Layout<Shape<_16, _16, _1>>>;
|
|
|
|
// A (M,K) K-major
|
|
using SmemLayoutAtomA = Layout<Shape <_128, _16>,
|
|
Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 1;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementA>, ElementA>{},
|
|
Layout<Shape <_16,_16>,
|
|
Stride<_16, _1>>{}));
|
|
|
|
// B (N,K) K-major
|
|
using SmemLayoutAtomB = Layout<Shape <_128, _16>,
|
|
Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 1;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementB>, ElementB>{},
|
|
Layout<Shape <_16,_16>,
|
|
Stride<_16, _1>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Multi Stage NN
|
|
template <
|
|
class ElementA,
|
|
class ElementB,
|
|
class ElementC, class LayoutC,
|
|
class ElementAccumulator>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, arch::Sm80,
|
|
ElementA, cutlass::layout::ColumnMajor,
|
|
ElementB, cutlass::layout::ColumnMajor,
|
|
ElementC, LayoutC,
|
|
ElementAccumulator>
|
|
{
|
|
using TileShape = Shape<_128, _128, _16>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>,
|
|
Layout<Shape<_16, _16, _1>>, // 16x16x1 thread group
|
|
Tile<Layout<Shape<_16,_2>,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization
|
|
|
|
// A (M,K) M-major
|
|
using SmemLayoutAtomA = Layout<Shape<_128,_16>>;
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 2;
|
|
using AlignmentTypeA = cute::uint_byte_t<static_cast<int>(sizeof(ElementA)) * kAlignmentA>;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeA>, ElementA>{},
|
|
Layout<Shape<_32,_8>>{},
|
|
Layout<Shape< _2,_1>>{}));
|
|
|
|
// B (N,K) K-major
|
|
using SmemLayoutAtomB = Layout<Shape <_128, _16>,
|
|
Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 1;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementB>, ElementB>{},
|
|
Layout<Shape <_16,_16>,
|
|
Stride<_16, _1>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::ColumnMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// SIMT Multi Stage TT
|
|
template <
|
|
class ElementA,
|
|
class ElementB,
|
|
class ElementC, class LayoutC,
|
|
class ElementAccumulator>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassSimt, arch::Sm80,
|
|
ElementA, cutlass::layout::RowMajor,
|
|
ElementB, cutlass::layout::RowMajor,
|
|
ElementC, LayoutC,
|
|
ElementAccumulator>
|
|
{
|
|
using TileShape = Shape<_128, _128, _16>;
|
|
static constexpr int ThreadCount = 256;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<UniversalFMA<ElementAccumulator, ElementA, ElementB, ElementC>>,
|
|
Layout<Shape<_16, _16, _1>>, // 16x16x1 thread group
|
|
Tile<Underscore,Layout<Shape<_16,_2>,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization
|
|
|
|
// A (M,K) K-major
|
|
using SmemLayoutAtomA = Layout<Shape <_128, _16>,
|
|
Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, ElementA>;
|
|
static constexpr int kAlignmentA = 1;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<ElementA>, ElementA>{},
|
|
Layout<Shape <_16,_16>,
|
|
Stride<_16, _1>>{}));
|
|
|
|
// B (N,K) N-major
|
|
using SmemLayoutAtomB = Layout<Shape <_128,_16>>;
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, ElementB>;
|
|
static constexpr int kAlignmentB = 2;
|
|
using AlignmentTypeB = cute::uint_byte_t<static_cast<int>(sizeof(ElementB)) * kAlignmentB>;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeB>, ElementB>{},
|
|
Layout<Shape<_32,_8>>{},
|
|
Layout<Shape< _2,_1>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
ElementA, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
ElementB, TagToStrideB_t<cutlass::layout::RowMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<LayoutC>,
|
|
TagToStrideC_t<LayoutC>,
|
|
epilogue::thread::LinearCombination<ElementC, 1, ElementAccumulator, ElementAccumulator>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Ampere fp64 MMA TN (K-Major A and K-Major B)
|
|
template <>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
double, cutlass::layout::RowMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double>
|
|
{
|
|
using TileShape = Shape<_128, _64, _16>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>, // Atom
|
|
Layout<Shape<_2,_2,_1>>, // Atom layout
|
|
Tile<Layout<Shape<_16,_2>,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization
|
|
Layout<Shape<_16,_2>,Stride<_2,_1>>,
|
|
Underscore>>;
|
|
|
|
// A (M,K) K-Major
|
|
using SmemLayoutAtomA = decltype(
|
|
composition(Swizzle<2,0,4>{},
|
|
Layout<Shape <_4,_16>,
|
|
Stride<_1, _4>>{})); // M, K
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentA = 1;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
|
|
Layout<Shape < _8,_16>,
|
|
Stride<_16, _1>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_1,_1>>{})); // Value layout: 1x1 doubles
|
|
|
|
// B (N,K) K-Major
|
|
using SmemLayoutAtomB = decltype(
|
|
composition(Swizzle<2,0,4>{},
|
|
Layout<Shape <_4,_16>,
|
|
Stride<_1, _4>>{})); // N, K
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentB = 1;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
|
|
Layout<Shape < _8,_16>,
|
|
Stride<_16, _1>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_1,_1>>{})); // Value layout: 1x1 doubles
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
double, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
double, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
epilogue::thread::LinearCombination<double, 1, double, double>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
|
|
/*
|
|
using EpilogueOutputOp = epilogue::collective::Epilogue<
|
|
epilogue::thread::LinearCombination<double, 1, double, double>,
|
|
Layout<Shape <_64,_32>,
|
|
Stride< _1,_64>>, // SMEM layout
|
|
Copy_Atom<UniversalCopy<double>,double>, // R2S with tiled_mma layout
|
|
decltype(make_tiled_copy(Copy_Atom<UniversalCopy<double>,double>{},// S2R
|
|
Layout<Shape <_16,_16>,
|
|
Stride< _1,_16>>{}, // Thread layout
|
|
Layout<Shape<_2,_1>>{})), // Value layout
|
|
Copy_Atom<UniversalCopy<double>,double> // R2G with S2R_dst layout
|
|
>;
|
|
*/
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Ampere fp64 MMA NN (M-Major A and K-Major B)
|
|
template <>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double>
|
|
{
|
|
using TileShape = Shape<_128, _64, _16>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>, // Atom
|
|
Layout<Shape<_2,_2,_1>>, // Atom layout
|
|
Tile<Layout<Shape<_16,_2>,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization
|
|
Layout<Shape<_16,_2>,Stride<_2,_1>>,
|
|
Underscore>>;
|
|
|
|
// A (M,K) M-Major
|
|
using SmemLayoutAtomA = decltype(
|
|
composition(Swizzle<2,2,2>{},
|
|
Layout<Shape <_16, _4>,
|
|
Stride< _1,_16>>{})); // M, K
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentA = 2;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
|
|
Layout<Shape <_16, _8>,
|
|
Stride< _1,_16>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_2,_1>>{})); // Value layout: 2x1 doubles
|
|
|
|
// B (N,K) K-Major
|
|
using SmemLayoutAtomB = decltype(
|
|
composition(Swizzle<2,0,4>{},
|
|
Layout<Shape <_4,_16>,
|
|
Stride<_1, _4>>{}));// N, K
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentB = 1;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
|
|
Layout<Shape < _8,_16>,
|
|
Stride<_16, _1>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_1,_1>>{})); // Value layout: 1x1 doubles
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
double, TagToStrideA_t<cutlass::layout::ColumnMajor>,
|
|
double, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
epilogue::thread::LinearCombination<double, 1, double, double>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Ampere fp64 MMA NT (M-Major A and N-Major B)
|
|
template <>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double, cutlass::layout::RowMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double>
|
|
{
|
|
using TileShape = Shape<_128, _64, _16>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>, // Atom
|
|
Layout<Shape<_2,_2,_1>>, // Atom layout
|
|
Tile<Layout<Shape<_16,_2>,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization
|
|
Layout<Shape<_16,_2>,Stride<_2,_1>>,
|
|
Underscore>>;
|
|
|
|
// A (M,K) M-Major
|
|
using SmemLayoutAtomA = decltype(
|
|
composition(Swizzle<2,2,2>{},
|
|
Layout<Shape <_16, _4>,
|
|
Stride< _1,_16>>{})); // M, K
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentA = 2;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
|
|
Layout<Shape <_16, _8>,
|
|
Stride< _1,_16>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_2,_1>>{})); // Value layout: 2x1 doubles
|
|
|
|
// B (N,K) N-Major
|
|
using SmemLayoutAtomB = decltype(
|
|
composition(Swizzle<2,2,2>{},
|
|
Layout<Shape <_16, _4>,
|
|
Stride< _1,_16>>{})); // N, K
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentB = 2;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
|
|
Layout<Shape <_16, _8>,
|
|
Stride< _1,_16>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_2,_1>>{})); // Value layout: 2x1 doubles
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
double, TagToStrideA_t<cutlass::layout::ColumnMajor>,
|
|
double, TagToStrideB_t<cutlass::layout::RowMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
epilogue::thread::LinearCombination<double, 1, double, double>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Ampere fp64 MMA TT (K-Major A and N-Major B)
|
|
template <>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm80,
|
|
double, cutlass::layout::RowMajor,
|
|
double, cutlass::layout::RowMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double>
|
|
{
|
|
using TileShape = Shape<_128, _64, _16>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM80_8x8x4_F64F64F64F64_TN>, // Atom
|
|
Layout<Shape<_2,_2,_1>>, // Atom layout
|
|
Tile<Layout<Shape<_16,_2>,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization
|
|
Layout<Shape<_16,_2>,Stride<_2,_1>>,
|
|
Underscore>>;
|
|
|
|
// A (M,K) K-Major
|
|
using SmemLayoutAtomA = decltype(
|
|
composition(Swizzle<2,0,4>{},
|
|
Layout<Shape <_4,_16>,
|
|
Stride<_1, _4>>{})); // M, K
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentA = 1;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<double>, double>{}, // CopyAtom
|
|
Layout<Shape < _8,_16>,
|
|
Stride<_16, _1>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_1,_1>>{})); // Value layout: 1x1 doubles
|
|
|
|
// B (N,K) N-Major
|
|
using SmemLayoutAtomB = decltype(
|
|
composition(Swizzle<2,2,2>{},
|
|
Layout<Shape <_16, _4>,
|
|
Stride< _1,_16>>{})); // N, K
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentB = 2;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{}, // CopyAtom
|
|
Layout<Shape <_16, _8>,
|
|
Stride< _1,_16>>{}, // ThrLayout for CopyAtom
|
|
Layout<Shape<_2,_1>>{})); // Value layout: 2x1 doubles
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
double, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
double, TagToStrideB_t<cutlass::layout::RowMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = epilogue::collective::DefaultEpilogue<
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
TagToStrideC_t<cutlass::layout::ColumnMajor>,
|
|
epilogue::thread::LinearCombination<double, 1, double, double>,
|
|
cutlass::gemm::EpilogueDefault>;
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
// Hopper fp64 MMA TN
|
|
template <>
|
|
struct DefaultGemmConfigurationToCutlass3Types<
|
|
arch::OpClassTensorOp, arch::Sm90,
|
|
double, cutlass::layout::RowMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double, cutlass::layout::ColumnMajor,
|
|
double>
|
|
{
|
|
using TileShape = Shape<_128, _64, _16>;
|
|
static constexpr int ThreadCount = 128;
|
|
using DispatchPolicy = MainloopSm80CpAsync<3>;
|
|
using TiledMma = TiledMMA<
|
|
MMA_Atom<SM90_16x8x16_F64F64F64F64_TN>,
|
|
Layout<Shape<_2,_2,_1>>>;
|
|
|
|
// A (M,K) K-major
|
|
using SmemLayoutAtomA = decltype(
|
|
make_ordered_layout(Shape<_128,_16>{},
|
|
Step < _2, _1>{})); // M, K
|
|
using SmemCopyAtomA = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentA = 2;
|
|
using GmemTiledCopyA = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{},
|
|
Layout<Shape <_16,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_2>>{}));
|
|
|
|
// B (N,K) K-major
|
|
using SmemLayoutAtomB = decltype(
|
|
make_ordered_layout(Shape<_64,_16>{},
|
|
Step < _2, _1>{})); // N, K
|
|
using SmemCopyAtomB = Copy_Atom<DefaultCopy, double>;
|
|
static constexpr int kAlignmentB = 2;
|
|
using GmemTiledCopyB = decltype(
|
|
make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, double>{},
|
|
Layout<Shape <_16,_8>,
|
|
Stride< _8,_1>>{},
|
|
Layout<Shape < _1,_2>>{}));
|
|
|
|
// Mainloop
|
|
using CollectiveMainloop = collective::CollectiveMma<
|
|
DispatchPolicy, TileShape,
|
|
double, TagToStrideA_t<cutlass::layout::RowMajor>,
|
|
double, TagToStrideB_t<cutlass::layout::ColumnMajor>,
|
|
TiledMma,
|
|
GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A
|
|
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B
|
|
>;
|
|
|
|
// Epilogue
|
|
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
|
|
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
|
|
TileShape, Shape<_1,_1,_1>,
|
|
cutlass::epilogue::collective::EpilogueTileAuto,
|
|
double, double,
|
|
double, cutlass::layout::ColumnMajor, 1,
|
|
double, cutlass::layout::ColumnMajor, 1,
|
|
cutlass::epilogue::collective::EpilogueScheduleAuto
|
|
>::CollectiveOp;
|
|
|
|
};
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
|
|
|
|
} // namespace device
|
|
} // namespace gemm
|
|
} // namespace cutlass
|