cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/tensor.hpp"
#include "gather_tensor.hpp"
namespace cutlass::gemm::kernel {
///////////////////////////////////////////////////////////////////////////////
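//
// GEMM kernel that fuses a gather of the A and B operands into the mainloop.
// The global-memory tensors for A and B are wrapped by make_gather_tensor()
// (from gather_tensor.hpp), so the index functors GatherA_ / GatherB_ remap
// source coordinates on the fly as tiles are loaded. The remaining template
// parameters follow the usual CUTLASS 3.x kernel-layer convention; the tile
// scheduler tag is aliased below but not otherwise used by this kernel.
//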
template <
class ProblemShape_,
class CollectiveMainloop_,
class CollectiveEpilogue_,
class GatherA_,
class GatherB_,
class TileScheduler_ = void
>
class GemmGather
{
public:
//
// Type Aliases
//
using ProblemShape = ProblemShape_;
using TileSchedulerTag = TileScheduler_;
using TileScheduler = TileScheduler_;
static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
"ProblemShape{} should be <M,N,K> or <M,N,K,L>");
// Mainloop derived types
using CollectiveMainloop = CollectiveMainloop_;
using TileShape = typename CollectiveMainloop::TileShape;
using TiledMma = typename CollectiveMainloop::TiledMma;
using ArchTag = typename CollectiveMainloop::ArchTag;
using ElementA = typename CollectiveMainloop::ElementA;
using StrideA = typename CollectiveMainloop::StrideA;
using ElementB = typename CollectiveMainloop::ElementB;
using StrideB = typename CollectiveMainloop::StrideB;
using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
// Epilogue derived types
using CollectiveEpilogue = CollectiveEpilogue_;
using ElementC = typename CollectiveEpilogue::ElementC;
using StrideC = typename CollectiveEpilogue::StrideC;
using ElementD = typename CollectiveEpilogue::ElementD;
using StrideD = typename CollectiveEpilogue::StrideD;
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
using EpilogueParams = typename CollectiveEpilogue::Params;
static_assert(std::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
"Mainloop and epilogue do not agree on accumulator value type.");
using GatherA = GatherA_;
using GatherB = GatherB_;
static constexpr int SharedStorageSize = static_cast<int>(cute::max(
sizeof(typename CollectiveMainloop::SharedStorage),
sizeof(typename CollectiveEpilogue::SharedStorage)));
static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{});
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
// Device side arguments
struct Arguments {
GemmUniversalMode mode{};
ProblemShape problem_shape{};
MainloopArguments mainloop{};
EpilogueArguments epilogue{};
KernelHardwareInfo hw_info{};
GatherA gather_A{};
GatherB gather_B{};
};
// Kernel entry point API
struct Params {
GemmUniversalMode mode;
ProblemShape problem_shape;
MainloopParams mainloop;
EpilogueParams epilogue;
GatherA gather_A{};
GatherB gather_B{};
};
//
// Methods
//
// Convert to underlying arguments.
static
Params
to_underlying_arguments(Arguments const& args, void* workspace) {
(void) workspace;
return {
args.mode,
args.problem_shape,
CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
args.gather_A,
args.gather_B
};
}
static
Status
initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return Status::kSuccess;
}
static
bool
can_implement(Arguments const& args) {
return args.mode == GemmUniversalMode::kGemm or
(args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
}
static
int
get_workspace_size(Arguments const& args) {
return 0;
}
static constexpr
dim3
get_grid_shape(Params const& params) {
int batch_count = 1;
if constexpr (rank(ProblemShape{}) == 4) {
batch_count = cute::size<3>(params.problem_shape);
}
return dim3(
cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
batch_count
);
}
static constexpr
dim3
get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
CUTLASS_DEVICE
void
operator()(Params const& params, char* smem_buf) {
using namespace cute;
using X = Underscore;
// Preconditions
CUTE_STATIC_ASSERT(is_static<TileShape>::value);
// Separate out problem shape for convenience
// Optionally append 1s until problem shape is rank-4, in case it is only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
auto M = get<0>(problem_shape_MNKL);
auto N = get<1>(problem_shape_MNKL);
auto K = get<2>(problem_shape_MNKL);
auto L = get<3>(problem_shape_MNKL);
// Preconditions
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
// Get the appropriate blocks for this thread block -- potential for thread block locality
int thread_idx = int(threadIdx.x);
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
auto [m_coord, n_coord, l_coord] = blockIdx;
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l)
// Represent the full tensors
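// make_gather_tensor (gather_tensor.hpp) returns a tensor whose layout applies the
// gather functor when computing element offsets, so the tiling and copies below are
// written exactly as for a regular GEMM while actually reading gathered data.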
Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l)
Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l)
// Get batch slice
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)
// Slice to get the tiles this thread block is responsible for
Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
// Compute tile residues for predication
auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
// Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
TiledMma tiled_mma;
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
clear(accumulators);
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
int k_tile_count = size<2>(gA);
// Perform the collective scoped MMA
CollectiveMainloop collective_mma;
collective_mma(
accumulators,
gA,
gB,
accumulators,
k_tile_iter, k_tile_count,
residue_mnk,
thread_idx,
smem_buf
);
// Epilogue and write to gD
CollectiveEpilogue epilogue{params.epilogue};
epilogue(
problem_shape_MNKL,
blk_shape,
blk_coord_mnkl,
accumulators,
tiled_mma,
residue_mnk,
thread_idx,
smem_buf
);
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::kernel
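//
// Usage sketch (illustrative only; the names below are hypothetical placeholders,
// not part of this header). A gather functor is any callable that maps a logical
// index to a storage index; the gather-tensor machinery this kernel relies on
// lives in gather_tensor.hpp. Assuming collectives built elsewhere:
//
//   struct RowIndexGather {
//     int const* indices;  // device pointer holding the gathered row order
//     CUTLASS_HOST_DEVICE int operator()(int i) const { return indices[i]; }
//   };
//
//   using Kernel = cutlass::gemm::kernel::GemmGather<
//       cute::Shape<int, int, int>,   // ProblemShape (M,N,K)
//       CollectiveMainloop,           // e.g. from cutlass::gemm::collective::CollectiveBuilder
//       CollectiveEpilogue,           // e.g. from cutlass::epilogue::collective::CollectiveBuilder
//       RowIndexGather,               // gather applied when reading A
//       RowIndexGather>;              // gather applied when reading B
//
//   Kernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
//                          {M, N, K},
//                          mainloop_args, epilogue_args, hw_info,
//                          RowIndexGather{a_indices}, RowIndexGather{b_indices}};
//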