/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*!
  \file
  \brief Performs a dual gemm in one fused kernel:

  ```
      D0 = epilogue0(X @ B0, C0)
      D1 = epilogue1(X @ B1, C1)
      D2 = element_wise(D0, D1)
  ```
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/arch/arch.h"
#include "cutlass/device_kernel.h"

#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/threadblock/default_mma.h"

#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"

#include "../kernel/dual_gemm.h"
#include "../dual_gemm_common.h"

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    /// Element type for A matrix operand
    typename ElementA_,
    /// Layout type for A matrix operand
    typename LayoutA_,
    /// Element type for B matrix operand
    typename ElementB_,
    /// Layout type for B0 matrix operand
    typename LayoutB0_,
    /// Layout type for B1 matrix operand
    typename LayoutB1_,
    /// Element type for C and D matrix operands
    typename ElementC_,
    /// Layout type for C and D matrix operands
    typename LayoutC_,
    /// Element type for internal accumulation
    typename ElementAccumulator_,
    /// Operator class tag
    typename OperatorClass_,
    /// Tag indicating architecture to tune for
    typename ArchTag_,
    /// Threadblock-level tile size (concept: GemmShape)
    typename ThreadblockShape_,
    /// Warp-level tile size (concept: GemmShape)
    typename WarpShape_,
    /// Instruction-level tile size (concept: GemmShape)
    typename InstructionShape_,
    /// Epilogue output operators
    typename EpilogueOutputOp0_,
    typename EpilogueOutputOp1_,
    typename EpilogueOutputOp2_,
    /// Threadblock-level swizzling operator
    typename ThreadblockSwizzle_ =
        threadblock::GemmIdentityThreadblockSwizzle<>,
    /// Number of stages used in the pipelined mainloop
    int Stages = DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::kStages,
    bool StoreD0 = true,
    bool StoreD1 = true,
    /// If true, kernel supports split-K with serial reduction
    bool SplitKSerial = false,
    /// Access granularity of A matrix in units of elements
    int AlignmentA = DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::kAlignmentA,
    /// Access granularity of B matrix in units of elements
    int AlignmentB = DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::kAlignmentB,
    /// Operation performed by GEMM
    typename Operator_ = typename DefaultGemmConfiguration<
        OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_,
        ElementAccumulator_>::Operator>
class DualGemm {
public:

  using ElementA = ElementA_;
  using LayoutA = LayoutA_;
  using TensorRefA = TensorRef<ElementA const, LayoutA>;
  using ElementB = ElementB_;
  using LayoutB0 = LayoutB0_;
  using LayoutB1 = LayoutB1_;
  using TensorRefB0 = TensorRef<ElementB const, LayoutB0>;
  using TensorRefB1 = TensorRef<ElementB const, LayoutB1>;
  using ElementC = ElementC_;
  using LayoutC = LayoutC_;
  using TensorRefC = TensorRef<ElementC const, LayoutC>;
  using TensorRefD = TensorRef<ElementC, LayoutC>;
  using ElementAccumulator = ElementAccumulator_;
  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;
  using ThreadblockShape = ThreadblockShape_;
  using WarpShape = WarpShape_;
  using InstructionShape = InstructionShape_;
  using EpilogueOutputOp0 = EpilogueOutputOp0_;
  using EpilogueOutputOp1 = EpilogueOutputOp1_;
  using EpilogueOutputOp2 = EpilogueOutputOp2_;
  using ThreadblockSwizzle = ThreadblockSwizzle_;
  using Operator = Operator_;

  static int const kStages = Stages;
  static int const kAlignmentA = AlignmentA;
  static int const kAlignmentB = AlignmentB;
  static int const kAlignmentC = EpilogueOutputOp1::kCount;
  static bool const kSplitKSerial = SplitKSerial;
  static bool constexpr kStoreD0 = StoreD0;
  static bool constexpr kStoreD1 = StoreD1;

  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;

  using LayoutScaleBias = layout::RowMajor;

  /// Define the kernel
  /// Define the threadblock-scoped matrix multiply-accumulate
  static_assert(ArchTag::kMinComputeCapability >= 80,
                "Only multistage is implemented");
  static_assert(kStages >= 3, "Only multistage is implemented");

  using Mma0 = typename cutlass::gemm::threadblock::DefaultMma<
      ElementA, LayoutA, kAlignmentA,
      ElementB, LayoutB0, kAlignmentB,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, ArchTag,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, Operator>::ThreadblockMma;
  using Mma1 = typename cutlass::gemm::threadblock::DefaultMma<
      ElementA, LayoutA, kAlignmentA,
      ElementB, LayoutB1, kAlignmentB,
      ElementAccumulator, layout::RowMajor,
      arch::OpClassTensorOp, ArchTag,
      ThreadblockShape, WarpShape, InstructionShape,
      Stages, Operator>::ThreadblockMma;

  // The dual mainloop stages the A operand once (single iterator and shared-memory iterator)
  // and streams B0 and B1 through separate iterators, accumulating both GEMMs in one pass.
  using DualMma = threadblock::DualMmaMultistage<
      typename Mma0::Shape,
      typename Mma0::IteratorA,
      typename Mma0::SmemIteratorA,
      Mma0::kCacheOpA,
      typename Mma0::IteratorB,
      typename Mma0::SmemIteratorB,
      Mma0::kCacheOpB,
      typename Mma1::IteratorB,
      typename Mma1::SmemIteratorB,
      typename Mma0::ElementC,
      typename Mma0::LayoutC,
      typename Mma0::Policy,
      typename Mma1::Policy,
      Mma0::kStages,
      SharedMemoryClearOption::kNone>;

  static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK;

  /// Define the epilogue
  using Epilogue0 =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename DualMma::Operator0, kPartitionsK,
          EpilogueOutputOp0, EpilogueOutputOp0::kCount>::Epilogue;
  using Epilogue1 =
      typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<
          ThreadblockShape, typename DualMma::Operator1, kPartitionsK,
          EpilogueOutputOp1, EpilogueOutputOp1::kCount>::Epilogue;

  /// Define the kernel-level GEMM operator.
  using DualGemmKernel = kernel::DualGemm<
      DualMma,
      Epilogue0, Epilogue1, EpilogueOutputOp2,
      ThreadblockSwizzle, kSplitKSerial,
      kStoreD0, kStoreD1>;

  /// Argument structure
  struct Arguments {
    //
    // Data members
    //

    DualGemmMode mode;
    GemmCoord problem_size;
    TensorRef<ElementA const, LayoutA> ref_A0;
    TensorRef<ElementB const, LayoutB0> ref_B0;
    TensorRef<ElementC const, LayoutC> ref_C0;
    TensorRef<ElementC, LayoutC> ref_D0;
    TensorRef<ElementB const, LayoutB1> ref_B1;
    TensorRef<ElementC const, LayoutC> ref_C1;
    TensorRef<ElementC, LayoutC> ref_D1;
    TensorRef<ElementC, LayoutC> ref_D2;
    typename EpilogueOutputOp0::Params epilogue0;
    typename EpilogueOutputOp1::Params epilogue1;
    typename EpilogueOutputOp2::Params epilogue2;
    int split_k_slices;

    int batch_count;
    int64_t batch_stride_A;
    int64_t batch_stride_B0;
    int64_t batch_stride_B1;
    int64_t batch_stride_C;
    int64_t batch_stride_D;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments(): problem_size(0, 0, 0), split_k_slices(1) { }

    /// Constructs an Arguments structure
    CUTLASS_HOST_DEVICE
    Arguments(
      DualGemmMode mode,
      GemmCoord problem_size_,
      TensorRef<ElementA const, LayoutA> ref_A0_,
      TensorRef<ElementB const, LayoutB0> ref_B0_,
      TensorRef<ElementC const, LayoutC> ref_C0_,
      TensorRef<ElementC, LayoutC> ref_D0_,
      TensorRef<ElementB const, LayoutB1> ref_B1_,
      TensorRef<ElementC const, LayoutC> ref_C1_,
      TensorRef<ElementC, LayoutC> ref_D1_,
      TensorRef<ElementC, LayoutC> ref_D2_,
      typename EpilogueOutputOp0::Params epilogue0_ =
        typename EpilogueOutputOp0::Params(),
      typename EpilogueOutputOp1::Params epilogue1_ =
        typename EpilogueOutputOp1::Params(),
      typename EpilogueOutputOp2::Params epilogue2_ =
        typename EpilogueOutputOp2::Params(),
      int split_k_slices_ = 1,
      int batch_count = 1,
      int64_t batch_stride_A = 0,
      int64_t batch_stride_B0 = 0,
      int64_t batch_stride_B1 = 0,
      int64_t batch_stride_C = 0,
      int64_t batch_stride_D = 0
    ):
      mode(mode),
      problem_size(problem_size_),
      ref_A0(ref_A0_),
      ref_B0(ref_B0_),
      ref_C0(ref_C0_),
      ref_D0(ref_D0_),
      ref_B1(ref_B1_),
      ref_C1(ref_C1_),
      ref_D1(ref_D1_),
      ref_D2(ref_D2_),
      epilogue0(epilogue0_),
      epilogue1(epilogue1_),
      epilogue2(epilogue2_),
      split_k_slices(split_k_slices_),
      batch_count(batch_count),
      batch_stride_A(batch_stride_A),
      batch_stride_B0(batch_stride_B0),
      batch_stride_B1(batch_stride_B1),
      batch_stride_C(batch_stride_C),
      batch_stride_D(batch_stride_D) { }
  };

private:

  /// Kernel parameters object
  typename DualGemmKernel::Params params_;

public:

  /// Constructs the GEMM.
  DualGemm() = default;

  /// Determines whether the GEMM can execute the given problem.
  static Status can_implement(Arguments const &args) {

    if (args.mode == DualGemmMode::kBatched && kSplitKSerial) {
      return Status::kErrorInvalidProblem;
    }

    if (!kSplitKSerial && args.split_k_slices > 1) {
      return Status::kErrorInvalidProblem;
    }

    if (kStoreD0 != (args.ref_D0.data() != nullptr)) {
      return Status::kErrorInternal;
    }

    if (kStoreD1 != (args.ref_D1.data() != nullptr)) {
      return Status::kErrorInternal;
    }

    Status status = DualGemmKernel::can_implement(
      args.problem_size,
      args.ref_A0.non_const_ref(),
      args.ref_B0.non_const_ref(),
      args.ref_C0.non_const_ref(),
      args.ref_D0,
      args.ref_B1.non_const_ref(),
      args.ref_C1.non_const_ref(),
      args.ref_D1,
      args.ref_D2
    );

    if (status != Status::kSuccess) {
      return status;
    }

    return Status::kSuccess;
  }

  /// Gets the workspace size
  static size_t get_workspace_size(Arguments const &args) {

    size_t bytes = 0;

    if (kSplitKSerial && args.split_k_slices > 1) {

      // Determine grid shape
      ThreadblockSwizzle threadblock_swizzle;

      cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
        args.problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.split_k_slices);

      // One semaphore (int) per threadblock tile for serial split-K reduction.
      bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n());
    }

    return bytes;
  }

  /// Initializes GEMM state from arguments.
  Status initialize(
      Arguments const &args,
      void *workspace = nullptr,
      cudaStream_t stream = nullptr) {

    // Determine grid shape
    ThreadblockSwizzle threadblock_swizzle;

    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
      args.problem_size,
      {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
      args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices);

    if (kSplitKSerial) {
      if (args.split_k_slices > 1) {
        if (!workspace) {
          return Status::kErrorWorkspaceNull;
        }

        size_t bytes = get_workspace_size(args);

        cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream);

        if (result != cudaSuccess) {
          return Status::kErrorInternal;
        }
      }
    }
    else {
      if (args.split_k_slices > 1) {
        return Status::kErrorInvalidProblem;
      }
    }

    // Initialize the Params structure
    params_ = typename DualGemmKernel::Params{
      args.mode,
      args.problem_size,
      grid_shape,
      args.ref_A0.non_const_ref(),
      args.ref_B0.non_const_ref(),
      args.ref_C0.non_const_ref(),
      args.ref_D0,
      args.ref_B1.non_const_ref(),
      args.ref_C1.non_const_ref(),
      args.ref_D1,
      args.ref_D2,
      args.epilogue0,
      args.epilogue1,
      args.epilogue2,
      reinterpret_cast<int *>(workspace),
      args.batch_stride_A,
      args.batch_stride_B0,
      args.batch_stride_B1,
      args.batch_stride_C,
      args.batch_stride_D,
    };

    return Status::kSuccess;
  }

  /// Lightweight update given a subset of arguments
  Status update(Arguments const &args, void *workspace = nullptr) {

    if (kSplitKSerial && args.split_k_slices > 1) {
      if (!workspace) {
        return Status::kErrorWorkspaceNull;
      }
    }

    params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
    params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
    params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
    params_.ref_D0.reset(args.ref_D0.data());
    params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
    params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
    params_.ref_D1.reset(args.ref_D1.data());
    params_.ref_D2.reset(args.ref_D2.data());
    params_.output_op_0 = args.epilogue0;
    params_.output_op_1 = args.epilogue1;
    params_.output_op_2 = args.epilogue2;
    params_.semaphore = reinterpret_cast<int *>(workspace);

    return Status::kSuccess;
  }

  /// Runs the kernel using initialized state.
  Status run(cudaStream_t stream = nullptr) {

    ThreadblockSwizzle threadblock_swizzle;

    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
    dim3 block(DualGemmKernel::kThreadCount, 1, 1);

    cudaError_t result;

    int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage));
    if (smem_size >= (48 << 10)) {
      // Opt in to more than 48 KB of dynamic shared memory per threadblock.
      result = cudaFuncSetAttribute(
          Kernel<DualGemmKernel>,
          cudaFuncAttributeMaxDynamicSharedMemorySize,
          smem_size);

      if (result != cudaSuccess) {
        return Status::kErrorInternal;
      }
    }

    cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_);

    result = cudaGetLastError();

    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
  }

  /// Runs the kernel using initialized state.
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }

  /// Runs the kernel using initialized state.
  Status operator()(
      Arguments const &args,
      void *workspace = nullptr,
      cudaStream_t stream = nullptr) {

    Status status = initialize(args, workspace, stream);

    if (status == Status::kSuccess) {
      status = run(stream);
    }

    return status;
  }
};

} // namespace device
} // namespace gemm
} // namespace cutlass

////////////////////////////////////////////////////////////////////////////////
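
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Usage sketch (illustrative only; not part of this header's API). It shows one plausible way to
// instantiate and launch DualGemm for half-precision operands on SM80. The tile shapes are just
// reasonable choices, `ElementwiseMulOp` stands in for any user-provided functor satisfying the
// EpilogueOutputOp2 concept (e.g. the multiply used by GLU-style activations), and the tensor
// refs, epilogue params, stream, workspace pointer, and DualGemmMode::kGemm enumerator are
// assumed to be provided by the caller / the companion headers.
//
//   using ElementOutput      = cutlass::half_t;
//   using ElementAccumulator = float;
//
//   using EpilogueOutputOp01 = cutlass::epilogue::thread::LinearCombination<
//       ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
//       ElementAccumulator, ElementAccumulator>;
//
//   using DualGemm = cutlass::gemm::device::DualGemm<
//       cutlass::half_t, cutlass::layout::RowMajor,       // A
//       cutlass::half_t,                                   // B element
//       cutlass::layout::ColumnMajor,                      // B0 layout
//       cutlass::layout::ColumnMajor,                      // B1 layout
//       ElementOutput, cutlass::layout::RowMajor,          // C / D
//       ElementAccumulator,
//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
//       cutlass::gemm::GemmShape<128, 64, 32>,             // threadblock tile
//       cutlass::gemm::GemmShape<64, 32, 32>,              // warp tile
//       cutlass::gemm::GemmShape<16, 8, 16>,               // instruction tile
//       EpilogueOutputOp01, EpilogueOutputOp01,
//       ElementwiseMulOp,                                  // combines D0 and D1 into D2
//       cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
//       /*Stages=*/3>;
//
//   DualGemm dual_gemm;
//   DualGemm::Arguments args(
//       DualGemmMode::kGemm, problem_size,
//       ref_A0, ref_B0, ref_C0, ref_D0,
//       ref_B1, ref_C1, ref_D1, ref_D2,
//       epilogue0_params, epilogue1_params, epilogue2_params);
//
//   size_t workspace_bytes = DualGemm::get_workspace_size(args);  // non-zero only for split-K
//   // ... allocate `workspace` of workspace_bytes on the device if needed ...
//   if (DualGemm::can_implement(args) == cutlass::Status::kSuccess &&
//       dual_gemm.initialize(args, workspace, stream) == cutlass::Status::kSuccess) {
//     dual_gemm(stream);
//   }
//
/////////////////////////////////////////////////////////////////////////////////////////////////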