/***************************************************************************************************
 * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are permitted
 * provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice, this list of
 *       conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright notice, this list of
 *       conditions and the following disclaimer in the documentation and/or other materials
 *       provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
 *       to endorse or promote products derived from this software without specific prior written
 *       permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for GEMM performing a reduction over K partitions in parallel.
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/kernel/gemm.h"

#include "cutlass/gemm/kernel/default_gemm_splitk_parallel.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"

#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/reduction/kernel/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

////////////////////////////////////////////////////////////////////////////////

/*!
  Gemm device-level operator performing parallel reduction over the K partition.
*/
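/*
  Illustrative usage sketch (added commentary, not part of the original header).
  It assumes single-precision, column-major operands and the default template
  arguments; m, n, k, alpha, beta, the leading dimensions lda/ldb/ldc/ldd, and
  the device pointers ptr_A/ptr_B/ptr_C/ptr_D are placeholders supplied by the
  caller. The split-K workspace must be allocated by the caller and sized with
  get_workspace_size().

    #include "cutlass/gemm/device/gemm_splitk_parallel.h"

    using Gemm = cutlass::gemm::device::GemmSplitKParallel<
      float, cutlass::layout::ColumnMajor,   // A
      float, cutlass::layout::ColumnMajor,   // B
      float, cutlass::layout::ColumnMajor>;  // C and D

    Gemm gemm_op;

    int split_k_slices = 16;

    Gemm::Arguments args({m, n, k},
                         {ptr_A, lda},
                         {ptr_B, ldb},
                         {ptr_C, ldc},
                         {ptr_D, ldd},
                         {alpha, beta},
                         split_k_slices);

    // Query and allocate the workspace that holds the partial accumulations.
    size_t workspace_size = Gemm::get_workspace_size(args);

    void *workspace = nullptr;
    cudaMalloc(&workspace, workspace_size);

    // Initialize once, then run; equivalent to calling gemm_op(args, workspace).
    cutlass::Status status = gemm_op.initialize(args, workspace);

    if (status == cutlass::Status::kSuccess) {
      status = gemm_op.run();
    }
*/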
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_ = ElementC_,
    /// Operator class tag
    typename OperatorClass_ = arch::OpClassSimt,
    /// Tag indicating architecture to tune for. This is the minimum SM that
    /// supports the intended feature. The device kernel can be built
    /// targeting any SM larger than this number.
    typename ArchTag_ = arch::Sm70,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::ThreadblockShape,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::WarpShape,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::InstructionShape,
    /// Epilogue output operator
    typename EpilogueOutputOp_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::EpilogueOutputOp,
    /// Conversion operator applied when storing partial accumulations to workspace
    typename ConvertScaledOp_ = cutlass::epilogue::thread::Convert<
        ElementAccumulator_,
        DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_, ElementB_,
                                 ElementC_, ElementAccumulator_>::EpilogueOutputOp::kCount,
        ElementAccumulator_>,
    /// Reduction operator
    typename ReductionOp_ = cutlass::reduction::thread::ReduceAdd<
        ElementAccumulator_, typename EpilogueOutputOp_::ElementAccumulator,
        EpilogueOutputOp_::kCount>,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        threadblock::GemmSplitKHorizontalThreadblockSwizzle,
    /// Number of stages used in the pipelined mainloop
    int Stages = DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_,
                                          ElementB_, ElementC_,
                                          ElementAccumulator_>::kStages,
    /// Access granularity of A matrix in units of elements
    int kAlignmentA = DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_,
                                               ElementB_, ElementC_,
                                               ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int kAlignmentB = DefaultGemmConfiguration<OperatorClass_, ArchTag_, ElementA_,
                                               ElementB_, ElementC_,
                                               ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator>
class GemmSplitKParallel {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ConvertScaledOp = ConvertScaledOp_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ReductionOp = ReductionOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;

  /// GEMM kernel
  using GemmKernel = typename kernel::DefaultGemmSplitKParallel<
    ElementA,
    LayoutA,
    kAlignmentA,
    ElementB,
    LayoutB,
    kAlignmentB,
    ElementAccumulator,
    LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    ConvertScaledOp,
    ThreadblockSwizzle,
    kStages,
    Operator
  >::GemmKernel;

  /// Reduction kernel
  using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
    cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
    EpilogueOutputOp,
    ReductionOp
  >;

  //
  //
  //
  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;
    typename ConvertScaledOp::Params convert;
    typename ReductionOp::Params reduction;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1,
      typename ConvertScaledOp::Params convert_ =
        typename ConvertScaledOp::Params(),
      typename ReductionOp::Params reduction_ =
        typename ReductionOp::Params()
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices),
      convert(convert_),
      reduction(reduction_) { }
  };

private:

  /// Kernel parameters object
  typename GemmKernel::Params gemm_params_;

  /// Reduction kernel parameters object
  typename ReductionKernel::Params reduction_params_;

public:

  /// Constructs the GEMM.
  GemmSplitKParallel() { }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    // TODO

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    return sizeof(ElementAccumulator_) * size_t(args.problem_size.m()) *
           size_t(args.problem_size.n()) * grid_shape.k();
  }
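  /*
    Worked example (added commentary, not from the original source): for a
    4096 x 4096 x 4096 problem with float accumulators and split_k_slices = 16,
    the tiled shape returned by the swizzle has grid_shape.k() == 16, so the
    workspace holds 16 full-size partial D matrices:

      4 bytes * 4096 * 4096 * 16 = 1 GiB

    The workspace therefore grows linearly with the number of K partitions and
    can dominate device memory usage for large problems.
  */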
  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.split_k_slices);

    // Define a reference to the workspace - this is an aligned region in device memory.
    if (!workspace) {
      return Status::kErrorWorkspaceNull;
    }

    TensorRef<ElementAccumulator_, layout::RowMajor> ref_workspace(
      static_cast<ElementAccumulator_ *>(workspace),
      args.problem_size.n());

    int64_t partition_stride = int64_t(args.problem_size.m()) * int64_t(args.problem_size.n());

    // Initialize the Params structure
    gemm_params_ = typename GemmKernel::Params{
      args.problem_size,
      grid_shape,
      args.ref_A.non_const_ref(),
      args.ref_B.non_const_ref(),
      ref_workspace,
      args.convert,
      partition_stride
    };

    reduction_params_ = typename ReductionKernel::Params(
      args.problem_size.mn(),
      grid_shape.k(),
      partition_stride,
      ref_workspace,
      args.ref_D,
      args.ref_C.non_const_ref(),
      args.epilogue
    );

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    if (!workspace) {
      return Status::kErrorWorkspaceNull;
    }

    gemm_params_.ref_A.reset(args.ref_A.data());
    gemm_params_.ref_B.reset(args.ref_B.data());
    gemm_params_.ref_D.reset(static_cast<ElementAccumulator_ *>(workspace));

    reduction_params_.ref_D.reset(args.ref_D.data());
    reduction_params_.ref_C.reset(args.ref_C.data());

    return Status::kSuccess;
  }
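  /*
    Added note: run() below issues two kernel launches on the caller's stream.
    The split-K GEMM writes one partial product per K partition into the
    workspace, then the ReduceSplitK kernel reduces the partitions and applies
    the epilogue to produce D. Because both kernels are enqueued on the same
    stream, no explicit synchronization is needed between them; launch errors
    are reported only through cudaGetLastError().
  */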
  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    //
    // Launch GEMM kernel
    //

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(gemm_params_.grid_tiled_shape);
    dim3 block(GemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {
      result = cudaFuncSetAttribute(
        Kernel<GemmKernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

    result = cudaGetLastError();

    if (result != cudaSuccess) {
      return Status::kErrorInternal;
    }

    //
    // Launch reduction kernel
    //

    block = ReductionKernel::block_shape();
    grid = ReductionKernel::grid_shape(gemm_params_.problem_size.mn());

    Kernel<ReductionKernel><<<grid, block, 0, stream>>>(reduction_params_);

    result = cudaGetLastError();

    if (result != cudaSuccess) {
      return Status::kErrorInternal;
    }

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes, then runs the kernel using the supplied arguments.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for column-major output
template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B matrix operand
    typename LayoutB_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for. This is the minimum SM that
    /// supports the intended feature. The device kernel can be built
    /// targeting any SM larger than this number.
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Epilogue output operator
    typename EpilogueOutputOp_,
    /// Conversion operator applied when storing partial accumulations to workspace
    typename ConvertScaledOp_,
    /// Reduction operator
    typename ReductionOp_,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_,
    /// Number of stages used in the pipelined mainloop
    int Stages,
    int kAlignmentA,
    int kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_>
class GemmSplitKParallel<ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_,
                         layout::ColumnMajor,  // partially specialized on LayoutC
                         ElementAccumulator_, OperatorClass_, ArchTag_,
                         ThreadblockShape_, WarpShape_, InstructionShape_,
                         EpilogueOutputOp_, ConvertScaledOp_, ReductionOp_,
                         ThreadblockSwizzle_, Stages, kAlignmentA, kAlignmentB,
                         Operator_> {
 public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using ElementB = ElementB_;
  using LayoutB = LayoutB_;
  using ElementC = ElementC_;
  using LayoutC = layout::ColumnMajor;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using ConvertScaledOp = ConvertScaledOp_;
  using EpilogueOutputOp = EpilogueOutputOp_;
  using ReductionOp = ReductionOp_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;
  static int const kStages = Stages;

  using UnderlyingOperator = GemmSplitKParallel<
    ElementB,
    typename layout::LayoutTranspose<LayoutB>::type,
    ElementA,
    typename layout::LayoutTranspose<LayoutA>::type,
    ElementC,
    layout::RowMajor,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    EpilogueOutputOp,
    ConvertScaledOp,
    ReductionOp,
    ThreadblockSwizzle,
    Stages,
    kAlignmentA,
    kAlignmentB,
    Operator
  >;

  using UnderlyingArguments = typename UnderlyingOperator::Arguments;
  using GemmKernel = typename UnderlyingOperator::GemmKernel;
  using ReductionKernel = typename UnderlyingOperator::ReductionKernel;

  /// Argument structure
  struct Arguments {

    //
    // Data members
    //

    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A;
    TensorRef<ElementB const, LayoutB> ref_B;
    TensorRef<ElementC const, LayoutC> ref_C;
    TensorRef<ElementC, LayoutC> ref_D;
    typename EpilogueOutputOp::Params epilogue;
    int split_k_slices;
    typename ConvertScaledOp::Params convert;
    typename ReductionOp::Params reduction;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments() { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A_,
      TensorRef<ElementB const, LayoutB> ref_B_,
      TensorRef<ElementC const, LayoutC> ref_C_,
      TensorRef<ElementC, LayoutC> ref_D_,
      typename EpilogueOutputOp::Params epilogue_ =
        typename EpilogueOutputOp::Params(),
      int split_k_slices = 1,
      typename ConvertScaledOp::Params convert_ =
        typename ConvertScaledOp::Params(),
      typename ReductionOp::Params reduction_ =
        typename ReductionOp::Params()
    ):
      problem_size(problem_size_),
      ref_A(ref_A_),
      ref_B(ref_B_),
      ref_C(ref_C_),
      ref_D(ref_D_),
      epilogue(epilogue_),
      split_k_slices(split_k_slices),
      convert(convert_),
      reduction(reduction_) { }
  };

private:

  /// Underlying, transposed GEMM operator
  UnderlyingOperator underlying_operator_;

public:

  /// Constructs the GEMM.
  GemmSplitKParallel() { }
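  /*
    Added note: this specialization handles column-major output by mapping the
    problem onto the row-major primary template above, using the identity
    D^T = B^T * A^T. Operands A and B are exchanged, their layouts are
    transposed, and the M and N extents are swapped, so a column-major
    {m, n, k} problem is dispatched as a row-major {n, m, k} problem.
    to_underlying_arguments() below performs this re-mapping.
  */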
  /// Helper to construct a transposed equivalent for the underlying GEMM operator
  static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
    return UnderlyingArguments(
      {args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
      {args.ref_B.data(), args.ref_B.stride(0)},
      {args.ref_A.data(), args.ref_A.stride(0)},
      {args.ref_C.data(), args.ref_C.stride(0)},
      {args.ref_D.data(), args.ref_D.stride(0)},
      args.epilogue,
      args.split_k_slices,
      args.convert,
      args.reduction
    );
  }

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    return UnderlyingOperator::can_implement(to_underlying_arguments(args));
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    return UnderlyingOperator::get_workspace_size(to_underlying_arguments(args));
  }

  /// Initializes GEMM state from arguments.
  Status initialize(Arguments const &args, void *workspace) {

    return underlying_operator_.initialize(to_underlying_arguments(args), workspace);
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    return underlying_operator_.update(to_underlying_arguments(args), workspace);
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    return underlying_operator_.run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Initializes, then runs the kernel using the supplied arguments.
  Status operator()(
    Arguments const &args,
    void *workspace = nullptr,
    cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////