440 lines
15 KiB
C
440 lines
15 KiB
C
![]() |
/***************************************************************************************************
|
||
|
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
|
||
|
*
|
||
|
* Redistribution and use in source and binary forms, with or without modification, are permitted
|
||
|
* provided that the following conditions are met:
|
||
|
* * Redistributions of source code must retain the above copyright notice, this list of
|
||
|
* conditions and the following disclaimer.
|
||
|
* * Redistributions in binary form must reproduce the above copyright notice, this list of
|
||
|
* conditions and the following disclaimer in the documentation and/or other materials
|
||
|
* provided with the distribution.
|
||
|
* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
|
||
|
* to endorse or promote products derived from this software without specific prior written
|
||
|
* permission.
|
||
|
*
|
||
|
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
|
||
|
* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
|
||
|
* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
|
||
|
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
||
|
* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
||
|
* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
|
||
|
* STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||
|
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||
|
*
|
||
|
**************************************************************************************************/
|
||
|
/*! \file
|
||
|
\brief Device-level template for a pipelined back-to-back (B2B) GEMM kernel. Does not compute batching; split-K is accepted only as serial reduction when SplitKSerial is true.
|
||
|
*/
|
||
|
|
||
|
#pragma once
|
||
|
|
||
|
#include "cutlass/cutlass.h"
|
||
|
#include "cutlass/numeric_types.h"
|
||
|
#include "cutlass/arch/arch.h"
|
||
|
#include "cutlass/device_kernel.h"
|
||
|
|
||
|
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||
|
|
||
|
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||
|
#include "cutlass/epilogue/thread/linear_combination_relu.h"
|
||
|
|
||
|
#include "kernel/b2b_gemm.h"
|
||
|
#include "kernel/default_b2b_gemm.h"
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|
||
|
|
||
|
namespace cutlass {
|
||
|
namespace gemm {
|
||
|
namespace device {
|
||
|
|
||
|
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||
|
|
||
|
template <
|
||
|
/// Element type for A matrix operand
|
||
|
typename ElementA_,
|
||
|
/// Layout type for A matrix operand
|
||
|
typename LayoutA_,
|
||
|
/// Element type for B matrix operand
|
||
|
typename ElementB_,
|
||
|
/// Layout type for B matrix operand
|
||
|
typename LayoutB_,
|
||
|
/// Element type for C and D matrix operands
|
||
|
typename ElementC_,
|
||
|
/// Layout type for C and D matrix operands
|
||
|
typename LayoutC_,
|
||
|
/// Element type for internal accumulation
|
||
|
typename ElementAccumulator_ = ElementC_,
|
||
|
/// Operator class tag
|
||
|
typename OperatorClass_ = arch::OpClassSimt,
|
||
|
/// Tag indicating architecture to tune for
|
||
|
typename ArchTag_ = arch::Sm70,
|
||
|
/// Threadblock-level tile size (concept: GemmShape)
|
||
|
typename ThreadblockShape0_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::ThreadblockShape,
|
||
|
/// Threadblock-level tile size (concept: GemmShape)
|
||
|
typename ThreadblockShape1_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::ThreadblockShape,
|
||
|
/// Warp-level tile size (concept: GemmShape)
|
||
|
typename WarpShape0_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::WarpShape,
|
||
|
/// Warp-level tile size (concept: GemmShape)
|
||
|
typename WarpShape1_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::WarpShape,
|
||
|
/// Instruction-level tile size (concept: GemmShape)
|
||
|
typename InstructionShape_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::InstructionShape,
|
||
|
/// Epilogue output operator
|
||
|
typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::EpilogueOutputOp,
|
||
|
/// Epilogue output operator
|
||
|
typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::EpilogueOutputOp,
|
||
|
/// Threadblock-level swizzling operator
|
||
|
typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>,
|
||
|
/// Number of stages used in the pipelined mainloop
|
||
|
int Stages =
|
||
|
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||
|
ElementC_, ElementAccumulator_>::kStages,
|
||
|
/// Access granularity of A matrix in units of elements
|
||
|
int AlignmentA =
|
||
|
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||
|
ElementC_, ElementAccumulator_>::kAlignmentA,
|
||
|
/// Access granularity of B matrix in units of elements
|
||
|
int AlignmentB =
|
||
|
DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
|
||
|
ElementC_, ElementAccumulator_>::kAlignmentB,
|
||
|
/// If true, kernel supports split-K with serial reduction
|
||
|
bool SplitKSerial = false,
|
||
|
/// Operation performed by GEMM
|
||
|
typename Operator_ = typename DefaultGemmConfiguration<
|
||
|
OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
|
||
|
ElementAccumulator_>::Operator,
|
||
|
/// Whether Beta is zero or not
|
||
|
bool IsBetaZero = false>
|
||
|
class B2bGemm {
|
||
|
public:
|
||
|
|
||
|
using ElementA = ElementA_;
|
||
|
using LayoutA = LayoutA_;
|
||
|
using TensorRefA = TensorRef<ElementA const, LayoutA>;
|
||
|
using ElementB = ElementB_;
|
||
|
using LayoutB = LayoutB_;
|
||
|
using TensorRefB = TensorRef<ElementB const, LayoutB>;
|
||
|
using ElementC = ElementC_;
|
||
|
using LayoutC = LayoutC_;
|
||
|
using TensorRefC = TensorRef<ElementC const, LayoutC>;
|
||
|
using TensorRefD = TensorRef<ElementC, LayoutC>;
|
||
|
using ElementAccumulator = ElementAccumulator_;
|
||
|
using OperatorClass = OperatorClass_;
|
||
|
using ArchTag = ArchTag_;
|
||
|
using ThreadblockShape0 = ThreadblockShape0_;
|
||
|
using ThreadblockShape1 = ThreadblockShape1_;
|
||
|
using WarpShape0 = WarpShape0_;
|
||
|
using WarpShape1 = WarpShape1_;
|
||
|
using InstructionShape = InstructionShape_;
|
||
|
using EpilogueOutputOp0 = EpilogueOutputOp0_;
|
||
|
using EpilogueOutputOp1 = EpilogueOutputOp1_;
|
||
|
using ThreadblockSwizzle = ThreadblockSwizzle_;
|
||
|
using Operator = Operator_;
|
||
|
static int const kStages = Stages;
|
||
|
static int const kAlignmentA = AlignmentA;
|
||
|
static int const kAlignmentB = AlignmentB;
|
||
|
static int const kAlignmentC = EpilogueOutputOp1::kCount;
|
||
|
static bool const kSplitKSerial = SplitKSerial;
|
||
|
static bool const kIsBetaZero = IsBetaZero;
|
||
|
static ComplexTransform const kTransformA = ComplexTransform::kNone;
|
||
|
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||
|
|
||
|
/// Define the kernel
|
||
|
using B2bGemmKernel = typename kernel::DefaultB2bGemm<
|
||
|
ElementA,
|
||
|
LayoutA,
|
||
|
kAlignmentA,
|
||
|
ElementB,
|
||
|
LayoutB,
|
||
|
kAlignmentB,
|
||
|
ElementC,
|
||
|
LayoutC,
|
||
|
ElementAccumulator,
|
||
|
OperatorClass,
|
||
|
ArchTag,
|
||
|
ThreadblockShape0,
|
||
|
ThreadblockShape1,
|
||
|
WarpShape0,
|
||
|
WarpShape1,
|
||
|
InstructionShape,
|
||
|
EpilogueOutputOp0,
|
||
|
EpilogueOutputOp1,
|
||
|
ThreadblockSwizzle,
|
||
|
kStages,
|
||
|
kSplitKSerial,
|
||
|
Operator,
|
||
|
kIsBetaZero
|
||
|
>::B2bGemmKernel;
|
||
|
|
||
|
/// Argument structure
|
||
|
struct Arguments {
|
||
|
|
||
|
//
|
||
|
// Data members
|
||
|
//
|
||
|
|
||
|
GemmCoord problem_size_0;
|
||
|
GemmCoord problem_size_1;
|
||
|
TensorRef<ElementA const, LayoutA> ref_A0;
|
||
|
TensorRef<ElementB const, LayoutB> ref_B0;
|
||
|
TensorRef<ElementC const, LayoutC> ref_C0;
|
||
|
TensorRef<ElementB const, LayoutB> ref_B1;
|
||
|
TensorRef<ElementC const, LayoutC> ref_C1;
|
||
|
TensorRef<ElementC, LayoutC> ref_D1;
|
||
|
typename EpilogueOutputOp0::Params epilogue0;
|
||
|
typename EpilogueOutputOp1::Params epilogue1;
|
||
|
int split_k_slices;
|
||
|
|
||
|
//
|
||
|
// Methods
|
||
|
//
|
||
|
|
||
|
/// Default ctor
|
||
|
CUTLASS_HOST_DEVICE
|
||
|
Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) {
|
||
|
|
||
|
}
|
||
|
|
||
|
/// Constructs an Arguments structure
|
||
|
CUTLASS_HOST_DEVICE
|
||
|
Arguments(
|
||
|
GemmCoord problem_size_0_,
|
||
|
GemmCoord problem_size_1_,
|
||
|
TensorRef<ElementA const, LayoutA> ref_A0_,
|
||
|
TensorRef<ElementB const, LayoutB> ref_B0_,
|
||
|
TensorRef<ElementC const, LayoutC> ref_C0_,
|
||
|
TensorRef<ElementB const, LayoutB> ref_B1_,
|
||
|
TensorRef<ElementC const, LayoutC> ref_C1_,
|
||
|
TensorRef<ElementC, LayoutC> ref_D1_,
|
||
|
typename EpilogueOutputOp0::Params epilogue0_ =
|
||
|
typename EpilogueOutputOp0::Params(),
|
||
|
typename EpilogueOutputOp1::Params epilogue1_ =
|
||
|
typename EpilogueOutputOp1::Params(),
|
||
|
int split_k_slices_ = 1
|
||
|
):
|
||
|
problem_size_0(problem_size_0_),
|
||
|
problem_size_1(problem_size_1_),
|
||
|
ref_A0(ref_A0_),
|
||
|
ref_B0(ref_B0_),
|
||
|
ref_C0(ref_C0_),
|
||
|
ref_B1(ref_B1_),
|
||
|
ref_C1(ref_C1_),
|
||
|
ref_D1(ref_D1_),
|
||
|
epilogue0(epilogue0_),
|
||
|
epilogue1(epilogue1_),
|
||
|
split_k_slices(split_k_slices_) {
|
||
|
|
||
|
}
|
||
|
};
|
||
|
|
||
|
private:
|
||
|
|
||
|
/// Kernel parameters object
|
||
|
typename B2bGemmKernel::Params params_;
|
||
|
|
||
|
public:
|
||
|
|
||
|
/// Constructs the GEMM.
|
||
|
B2bGemm() { }
|
||
|
|
||
|
/// Determines whether the GEMM can execute the given problem.
|
||
|
static Status can_implement(Arguments const &args) {
|
||
|
|
||
|
if (!kSplitKSerial && args.split_k_slices > 1) {
|
||
|
return Status::kErrorInvalidProblem;
|
||
|
}
|
||
|
|
||
|
Status status = B2bGemmKernel::can_implement(
|
||
|
args.problem_size_0,
|
||
|
args.problem_size_1,
|
||
|
args.ref_A0.non_const_ref(),
|
||
|
args.ref_B0.non_const_ref(),
|
||
|
args.ref_C0.non_const_ref(),
|
||
|
args.ref_B1.non_const_ref(),
|
||
|
args.ref_C1.non_const_ref(),
|
||
|
args.ref_D1
|
||
|
);
|
||
|
|
||
|
if (status != Status::kSuccess) {
|
||
|
return status;
|
||
|
}
|
||
|
|
||
|
return Status::kSuccess;
|
||
|
}
|
||
|
|
||
|
/// Gets the workspace size
|
||
|
static size_t get_workspace_size(Arguments const &args) {
|
||
|
|
||
|
size_t bytes = 0;
|
||
|
|
||
|
// Determine grid shape
|
||
|
ThreadblockSwizzle threadblock_swizzle;
|
||
|
|
||
|
cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
|
||
|
args.problem_size_0,
|
||
|
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||
|
args.split_k_slices);
|
||
|
|
||
|
if (kSplitKSerial && args.split_k_slices > 1) {
|
||
|
|
||
|
|
||
|
bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
|
||
|
}
|
||
|
|
||
|
return bytes;
|
||
|
}
|
||
|
|
||
|
/// Initializes GEMM state from arguments.
|
||
|
Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
|
||
|
|
||
|
// Determine grid shape
|
||
|
ThreadblockSwizzle threadblock_swizzle;
|
||
|
|
||
|
cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
|
||
|
args.problem_size_0,
|
||
|
{ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
|
||
|
args.split_k_slices);
|
||
|
// cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
|
||
|
// args.problem_size_1,
|
||
|
// {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
|
||
|
// args.split_k_slices);
|
||
|
|
||
|
if (kSplitKSerial) {
|
||
|
if (args.split_k_slices > 1) {
|
||
|
if (!workspace) {
|
||
|
return Status::kErrorWorkspaceNull;
|
||
|
}
|
||
|
|
||
|
size_t bytes = get_workspace_size(args);
|
||
|
|
||
|
cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);
|
||
|
|
||
|
if (result != cudaSuccess) {
|
||
|
return Status::kErrorInternal;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
else {
|
||
|
|
||
|
if (args.split_k_slices > 1) {
|
||
|
return Status::kErrorInvalidProblem;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Initialize the Params structure
|
||
|
params_ = typename B2bGemmKernel::Params{
|
||
|
args.problem_size_0,
|
||
|
args.problem_size_1,
|
||
|
grid_shape,
|
||
|
args.ref_A0.non_const_ref(),
|
||
|
args.ref_B0.non_const_ref(),
|
||
|
args.ref_C0.non_const_ref(),
|
||
|
args.ref_B1.non_const_ref(),
|
||
|
args.ref_C1.non_const_ref(),
|
||
|
args.ref_D1,
|
||
|
args.epilogue0,
|
||
|
args.epilogue1,
|
||
|
static_cast<int *>(workspace),
|
||
|
};
|
||
|
|
||
|
return Status::kSuccess;
|
||
|
}
|
||
|
|
||
|
/// Lightweight update given a subset of arguments
|
||
|
Status update(Arguments const &args, void *workspace = nullptr) {
|
||
|
|
||
|
if (kSplitKSerial && args.split_k_slices > 1) {
|
||
|
if (!workspace) {
|
||
|
return Status::kErrorWorkspaceNull;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
params_.ref_A0.reset(args.ref_A.non_const_ref().data());
|
||
|
params_.ref_B0.reset(args.ref_B.non_const_ref().data());
|
||
|
params_.ref_C0.reset(args.ref_C.non_const_ref().data());
|
||
|
params_.ref_B1.reset(args.ref_B.non_const_ref().data());
|
||
|
params_.ref_C1.reset(args.ref_C.non_const_ref().data());
|
||
|
params_.ref_D1.reset(args.ref_D.data());
|
||
|
params_.output_op_0 = args.epilogue0;
|
||
|
params_.output_op_1 = args.epilogue1;
|
||
|
params_.semaphore = static_cast<int *>(workspace);
|
||
|
|
||
|
return Status::kSuccess;
|
||
|
}
|
||
|
|
||
|
/// Runs the kernel using initialized state.
|
||
|
Status run(cudaStream_t stream = nullptr) {
|
||
|
|
||
|
ThreadblockSwizzle threadblock_swizzle;
|
||
|
|
||
|
dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
|
||
|
dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
|
||
|
|
||
|
cudaError_t result;
|
||
|
|
||
|
int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
|
||
|
if (smem_size >= (48 << 10)) {
|
||
|
result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>,
|
||
|
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||
|
smem_size);
|
||
|
|
||
|
if (result != cudaSuccess) {
|
||
|
return Status::kErrorInternal;
|
||
|
}
|
||
|
|
||
|
result = cudaFuncSetAttribute(
|
||
|
Kernel<B2bGemmKernel>,
|
||
|
cudaFuncAttributePreferredSharedMemoryCarveout, 100);
|
||
|
|
||
|
if (result != cudaSuccess) {
|
||
|
return Status::kErrorInternal;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
|
||
|
|
||
|
result = cudaGetLastError();
|
||
|
|
||
|
return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
|
||
|
}
|
||
|
|
||
|
/// Runs the kernel using initialized state.
|
||
|
Status operator()(cudaStream_t stream = nullptr) {
|
||
|
return run(stream);
|
||
|
}
|
||
|
|
||
|
/// Runs the kernel using initialized state.
|
||
|
Status operator()(
|
||
|
Arguments const &args,
|
||
|
void *workspace = nullptr,
|
||
|
cudaStream_t stream = nullptr) {
|
||
|
|
||
|
Status status = initialize(args, workspace);
|
||
|
|
||
|
if (status == Status::kSuccess) {
|
||
|
status = run(stream);
|
||
|
}
|
||
|
|
||
|
return status;
|
||
|
}
|
||
|
};
|
||
|
|
||
|
} // namespace device
|
||
|
} // namespace gemm
|
||
|
} // namespace cutlass
|
||
|
|
||
|
////////////////////////////////////////////////////////////////////////////////
|