cutlass/include/cutlass/gemm/dispatch_policy.hpp
2023-04-29 09:34:27 -04:00

171 lines
6.1 KiB
C++

/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/arch.h"
#include "cute/layout.hpp"
#include "cute/numeric/integral_constant.hpp"
//////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm {
using namespace cute;
//////////////////////////////////////////////////////////////////////////////
//
// Policies for categorical dispatch of mainloop against kernel grid schedules
//
struct KernelMultistage { };
struct KernelTma { };
struct KernelTmaWarpSpecialized { };
struct KernelTmaWarpSpecializedPingpong { };
struct KernelTmaWarpSpecializedCooperative { };
// Policies for dispatch of epilogue
struct EpilogueDefault { };
struct EpilogueTransposed { };
//
// Collective Mainloop Policies
//
// 2 stage pipeline through 1 stage in smem, 1 in rmem, WITHOUT predicated gmem loads
struct MainloopSm70TwoStageUnpredicated {
constexpr static int Stages = 2;
using ArchTag = arch::Sm70;
using Schedule = KernelMultistage;
using ClusterShape = Shape<_1,_1,_1>;
};
// 2 stage pipeline through 1 stage in smem, 1 in rmem, with predicated gmem loads
struct MainloopSm70TwoStage {
constexpr static int Stages = 2;
using ArchTag = arch::Sm70;
using Schedule = KernelMultistage;
using ClusterShape = Shape<_1,_1,_1>;
};
// n-buffer in smem (cp.async), pipelined with registers, WITHOUT predicated gmem loads
template<int Stages_>
struct MainloopSm80CpAsyncUnpredicated {
constexpr static int Stages = Stages_;
using ArchTag = arch::Sm80;
using Schedule = KernelMultistage;
using ClusterShape = Shape<_1,_1,_1>;
};
// n-buffer in smem (cp.async), pipelined with registers, with predicated gmem loads
template<int Stages_>
struct MainloopSm80CpAsync {
constexpr static int Stages = Stages_;
using ArchTag = arch::Sm80;
using Schedule = KernelMultistage;
using ClusterShape = Shape<_1,_1,_1>;
};
// n-buffer in smem (cp.async), pipelined with Hopper GMMA, WITHOUT predicated gmem loads
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>
>
struct MainloopSm90CpAsyncGmmaUnpredicated {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelMultistage;
};
// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>
>
struct MainloopSm90CpAsyncGmma {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelMultistage;
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
int PipelineAsyncMmaStages_ = 1
>
struct MainloopSm90TmaGmma {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_;
using ArchTag = arch::Sm90;
using Schedule = KernelTma;
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelTmaWarpSpecializedCooperative
>
struct MainloopSm90TmaGmmaWarpSpecialized {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
};
// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule
// With GMMA's A data from registers.
template<
int Stages_,
class ClusterShape_ = Shape<_1,_1,_1>,
class KernelSchedule = KernelTmaWarpSpecialized,
int PipelineAsyncMmaStages_ = 0
>
struct MainloopSm90TmaGmmaRmemAWarpSpecialized {
constexpr static int Stages = Stages_;
using ClusterShape = ClusterShape_;
constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_;
using ArchTag = arch::Sm90;
using Schedule = KernelSchedule;
static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative>,
"KernelSchedule must be one of the warp specialized policies");
};
//////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm