
/***************************************************************************************************
 * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/dispatch_policy.hpp"

#include "cute/tensor.hpp"

#include "gather_tensor.hpp"
namespace cutlass::gemm::kernel {

///////////////////////////////////////////////////////////////////////////////
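
// Kernel-layer GEMM entry point in the CUTLASS 3.x style. It follows the usual
// Params / get_grid_shape / operator() structure, but builds its global-memory A and B
// tensors with make_gather_tensor (see gather_tensor.hpp), so one mode of each operand is
// indexed indirectly through the user-supplied GatherA / GatherB functors rather than
// addressed contiguously.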
template <
  class ProblemShape_,
  class CollectiveMainloop_,
  class CollectiveEpilogue_,
  class GatherA_,
  class GatherB_,
  class TileScheduler_ = void
>
class GemmGather
{
public:
  //
  // Type Aliases
  //
  using ProblemShape = ProblemShape_;
  using TileSchedulerTag = TileScheduler_;
  using TileScheduler = TileScheduler_;
  static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4,
    "ProblemShape{} should be <M,N,K> or <M,N,K,L>");

  // Mainloop derived types
  using CollectiveMainloop = CollectiveMainloop_;
  using TileShape = typename CollectiveMainloop::TileShape;
  using TiledMma = typename CollectiveMainloop::TiledMma;
  using ArchTag = typename CollectiveMainloop::ArchTag;
  using ElementA = typename CollectiveMainloop::ElementA;
  using StrideA = typename CollectiveMainloop::StrideA;
  using ElementB = typename CollectiveMainloop::ElementB;
  using StrideB = typename CollectiveMainloop::StrideB;
  using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy;
  using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
  using MainloopArguments = typename CollectiveMainloop::Arguments;
  using MainloopParams = typename CollectiveMainloop::Params;

  // Epilogue derived types
  using CollectiveEpilogue = CollectiveEpilogue_;
  using ElementC = typename CollectiveEpilogue::ElementC;
  using StrideC = typename CollectiveEpilogue::StrideC;
  using ElementD = typename CollectiveEpilogue::ElementD;
  using StrideD = typename CollectiveEpilogue::StrideD;
  using EpilogueArguments = typename CollectiveEpilogue::Arguments;
  using EpilogueParams = typename CollectiveEpilogue::Params;
  static_assert(std::is_same_v<ElementAccumulator, typename CollectiveEpilogue::ElementAccumulator>,
    "Mainloop and epilogue do not agree on accumulator value type.");

  using GatherA = GatherA_;
  using GatherB = GatherB_;

  static constexpr int SharedStorageSize = static_cast<int>(cute::max(
    sizeof(typename CollectiveMainloop::SharedStorage),
    sizeof(typename CollectiveEpilogue::SharedStorage)));
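  // Mainloop and epilogue run one after the other and reuse the same dynamic shared memory
  // buffer, so only the larger of the two footprints needs to be reserved.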

  static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{});
  static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

  // Device side arguments
  struct Arguments {
    GemmUniversalMode mode{};
    ProblemShape problem_shape{};
    MainloopArguments mainloop{};
    EpilogueArguments epilogue{};
    KernelHardwareInfo hw_info{};
    GatherA gather_A{};
    GatherB gather_B{};
  };
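
  // Illustrative host-side construction (a sketch with hypothetical names; the exact contents
  // of MainloopArguments and EpilogueArguments depend on the collectives this kernel is
  // instantiated with):
  //
  //   typename GemmKernel::Arguments args{
  //     cutlass::gemm::GemmUniversalMode::kGemm,
  //     ProblemShape{M, N, K, L},
  //     {ptr_A, stride_A, ptr_B, stride_B},                  // MainloopArguments
  //     {{alpha, beta}, ptr_C, stride_C, ptr_D, stride_D},   // EpilogueArguments
  //     hw_info,
  //     GatherA{indices_A},                                  // index functor for one mode of A
  //     GatherB{indices_B}                                   // index functor for one mode of B
  //   };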

  // Kernel entry point API
  struct Params {
    GemmUniversalMode mode;
    ProblemShape problem_shape;
    MainloopParams mainloop;
    EpilogueParams epilogue;
    GatherA gather_A{};
    GatherB gather_B{};
  };

  //
  // Methods
  //

  // Convert to underlying arguments.
  static
  Params
  to_underlying_arguments(Arguments const& args, void* workspace) {
    (void) workspace;
    return {
      args.mode,
      args.problem_shape,
      CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace),
      CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace),
      args.gather_A,
      args.gather_B
    };
  }

  static
  Status
  initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
    return Status::kSuccess;
  }

  static
  bool
  can_implement(Arguments const& args) {
    return args.mode == GemmUniversalMode::kGemm or
          (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4);
  }

  static
  int
  get_workspace_size(Arguments const& args) {
    return 0;
  }

  static constexpr
  dim3
  get_grid_shape(Params const& params) {
    int batch_count = 1;
    if constexpr (rank(ProblemShape{}) == 4) {
      batch_count = cute::size<3>(params.problem_shape);
    }

    return dim3(
      cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))),
      cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))),
      batch_count
    );
  }
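
  // The grid computed above has one thread block per (BLK_M, BLK_N) output tile and per batch
  // index. For example (hypothetical sizes), M = N = 4096 with a 128x128 tile and L = 4 gives
  // dim3(32, 32, 4).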

  static constexpr
  dim3
  get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  CUTLASS_DEVICE
  void
  operator()(Params const& params, char* smem_buf) {
    using namespace cute;
    using X = Underscore;

    // Preconditions
    CUTE_STATIC_ASSERT(is_static<TileShape>::value);

    // Separate out problem shape for convenience
    // Optionally append 1s until the problem shape is rank-4, in case it is only rank-3 (MNK)
    auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{});
    auto M = get<0>(problem_shape_MNKL);
    auto N = get<1>(problem_shape_MNKL);
    auto K = get<2>(problem_shape_MNKL);
    auto L = get<3>(problem_shape_MNKL);
// Preconditions
|
|
static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
|
static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>.");
|
|
static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
|
static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>.");
|
|
|
|
// Get the appropriate blocks for this thread block -- potential for thread block locality
|
|
int thread_idx = int(threadIdx.x);
|
|
auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K)
|
|
auto [m_coord, n_coord, l_coord] = blockIdx;
|
|
auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l)
|
|
|
|
// Represent the full tensors
|
|
Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l)
|
|
Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l)
|
|
|
|
// Get batch slice
|
|
Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k)
|
|
Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k)
|
|
|
|
// Slice to get the tiles this thread block is responsible for
|
|
Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k)
|
|
Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
|
|
|
|
// Compute tile residues for predication
|
|
auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord
|
|
auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord
|
|
auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max
|
|
auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue);
|
|
|
|
// Allocate the tiled_mma and the accumulators for the (M,N) blk_shape
|
|
TiledMma tiled_mma;
|
|
Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N)
|
|
clear(accumulators);
|
|
|
|
auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA));
|
|
int k_tile_count = size<2>(gA);
|
|
|
|
// Perform the collective scoped MMA
|
|
CollectiveMainloop collective_mma;
|
|
collective_mma(
|
|
accumulators,
|
|
gA,
|
|
gB,
|
|
accumulators,
|
|
k_tile_iter, k_tile_count,
|
|
residue_mnk,
|
|
thread_idx,
|
|
smem_buf
|
|
);
|
|
|
|
// Epilogue and write to gD
|
|
CollectiveEpilogue epilogue{params.epilogue};
|
|
epilogue(
|
|
problem_shape_MNKL,
|
|
blk_shape,
|
|
blk_coord_mnkl,
|
|
accumulators,
|
|
tiled_mma,
|
|
residue_mnk,
|
|
thread_idx,
|
|
smem_buf
|
|
);
|
|
}
|
|
};
|
|
|
|

///////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::kernel
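
// Illustrative host-side launch sequence (a sketch under assumed type aliases; the actual
// runner that accompanies this kernel may differ):
//
//   using GemmKernel = cutlass::gemm::kernel::GemmGather<ProblemShape, Mainloop, Epilogue,
//                                                        GatherA, GatherB>;
//   if (GemmKernel::can_implement(args)) {
//     auto params = GemmKernel::to_underlying_arguments(args, /*workspace=*/nullptr);
//     dim3 grid   = GemmKernel::get_grid_shape(params);
//     dim3 block  = GemmKernel::get_block_shape();
//     cutlass::device_kernel<GemmKernel><<<grid, block, GemmKernel::SharedStorageSize, stream>>>(params);
//   }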