/*************************************************************************************************** * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! 
\file
    \brief Defines the FusedMultiHeadAttention class.

    The class chains three device-side stages, enqueued in this order by run():

      1) GEMM0 with epilogue fusion: multiplies Q by K and writes P (scaled by
         alpha0). Its epilogue visitor (EpilogueVisitorSoftmax) is also handed
         the ptr_Max / ptr_Sum softmax-statistics arrays.
      2) A lightweight softmax reduction kernel (ApplySoftmaxFinalReduction)
         whose arguments are built from block_Norm / block_Sum and their
         per-problem offsets.
      3) GEMM1 with mainloop fusion: multiplies P by V and writes O. ptr_Max and
         ptr_Sum are handed to its softmax mainloop-fusion stage.

    All three stages are enqueued on the same stream passed to run(), so each
    stage is ordered after the previous one.
*/

#pragma once

/////////////////////////////////////////////////////////////////////////////////////////////////

#include <cmath>
#include <iostream>
#include <limits>
#include <vector>

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"

#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
#include "gemm_grouped_with_softmax_visitor.h"

namespace cutlass {

/// Fused multi-head attention built from two grouped GEMMs and a softmax
/// reduction kernel.
///
/// Typical use: fill an Arguments object, call initialize(), then call run()
/// (or operator()) with the stream the work should be enqueued on.
template <
  typename ElementQ_,
  typename LayoutQ_,
  typename ElementK_,
  typename LayoutK_,
  typename ElementP_,
  typename LayoutP_,
  typename ElementCompute_,
  typename OperatorClass_,
  typename ArchTag_,
  typename ThreadblockShape0_,
  typename ThreadblockShape1_,
  typename WarpShape0_,
  typename WarpShape1_,
  typename InstructionShape_,
  int kStages0_,
  int kStages1_,
  bool UseMasking_ = false,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ =
      cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ =
      cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  int Alignment = 128 / cutlass::sizeof_bits<ElementP_>::value,
  typename ElementSoftmax_ = ElementP_
>
class FusedMultiHeadAttention {
public:

  using ElementQ = ElementQ_;
  using ElementK = ElementK_;
  using ElementP = ElementP_;
  using ElementV = ElementK;
  using ElementOutput = ElementP;
  using ElementAccumulator = ElementCompute_;

  using LayoutQ = LayoutQ_;
  using LayoutK = LayoutK_;
  using LayoutP = LayoutP_;
  using LayoutV = LayoutK;
  using LayoutO = LayoutP;

  // Storage types of the softmax statistics and the type they are computed in.
  // NOTE(review): these are fixed to half_t / float; the ElementSoftmax_
  // template parameter is not referenced anywhere in this class.
  using ElementNorm = cutlass::half_t;
  using ElementSum = cutlass::half_t;
  using ElementSoftmaxCompute = float;
  using LayoutNorm = cutlass::layout::RowMajor;

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;

  using ThreadblockShape0 = ThreadblockShape0_;
  using WarpShape0 = WarpShape0_;

  using ThreadblockShape1 = ThreadblockShape1_;
  using WarpShape1 = WarpShape1_;

  static int const Stages0 = kStages0_;
  static int const Stages1 = kStages1_;

  using InstructionShape = InstructionShape_;

  /// GEMM0 output operator; ScaleType::OnlyAlphaScaling applies only alpha to
  /// the accumulator (no source term).
  using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator,
    cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

  /// GEMM1 output operator.
  /// NOTE(review): with ScaleType::Nothing the accumulator is written unscaled,
  /// so alpha1 / beta passed to EpilogueOutputOp1::Params appear to have no
  /// effect — confirm this is intended.
  using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput,
    128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator,
    ElementAccumulator,
    cutlass::epilogue::thread::ScaleType::Nothing>;

  using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration<
    OperatorClass, ArchTag, ElementQ, ElementK, ElementP, ElementAccumulator>::Operator;

  /// True when P is column-major; forwarded to MapArguments and GemmKernel0.
  static bool const kInternalTranspose =
      cutlass::platform::is_same<LayoutP, cutlass::layout::ColumnMajor>::value;

  static bool const kUseMasking = UseMasking_;

  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_;
  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_;

  // NOTE(review): the Q/K operand alignments are hard-coded to 8 elements and
  // the `Alignment` template parameter is not referenced. 8 elements equal
  // 128 bits only for 16-bit element types — confirm before instantiating
  // with other Q/K element types.
  using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
    ElementQ,
    LayoutQ,
    cutlass::ComplexTransform::kNone,
    8,
    ElementK,
    LayoutK,
    cutlass::ComplexTransform::kNone,
    8,
    LayoutP,
    kInternalTranspose
  >;

  using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
    typename MapArguments::ElementA,
    typename MapArguments::LayoutA,
    MapArguments::kAlignmentA,
    typename MapArguments::ElementB,
    typename MapArguments::LayoutB,
    MapArguments::kAlignmentB,
    ElementP,
    typename MapArguments::LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    EpilogueOutputOp0,
    ThreadblockSwizzle,
    Stages0,
    true,
    Operator,
    cutlass::gemm::SharedMemoryClearOption::kNone
  >::GemmKernel;

  /// Epilogue visitor that produces P and the softmax statistics for GEMM0.
  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
    ThreadblockShape0,
    DefaultGemmKernel::kThreadCount,
    typename DefaultGemmKernel::Epilogue::OutputTileIterator,
    typename EpilogueOutputOp0::ElementCompute,
    ElementNorm,
    ElementSum,
    ElementSoftmaxCompute,
    EpilogueOutputOp0,
    kUseMasking
  >;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
    EpilogueVisitor,
    typename DefaultGemmKernel::Epilogue
  >::Epilogue;

  using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor<
    typename DefaultGemmKernel::Mma,
    Epilogue,
    ThreadblockSwizzle,
    kGroupScheduleMode0,
    kInternalTranspose,
    kUseMasking
  >;

  using GemmGrouped0 = cutlass::gemm::device::GemmGrouped<GemmKernel0>;

  using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
    ElementNorm,
    ElementSum,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape,
    true
  >;

  using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion<
    ElementP,
    LayoutP,
    cutlass::ComplexTransform::kNone,
    128 / cutlass::sizeof_bits<ElementP>::value,
    ElementV,
    LayoutV,
    cutlass::ComplexTransform::kNone,
    128 / cutlass::sizeof_bits<ElementV>::value,
    ElementNorm,
    LayoutNorm,
    ElementOutput,
    LayoutO,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp1,
    ThreadblockSwizzle,
    Stages1,
    kGroupScheduleMode1
  >::GemmKernel;

  using GemmGrouped1 = cutlass::gemm::device::GemmGrouped<GemmKernel1>;

public:

  /// Host-side arguments. Every pointer member is forwarded as-is to the
  /// grouped GEMMs or the reduction kernel; this class never allocates or
  /// frees any of them.
  struct Arguments {

    //
    // Data members
    //

    cutlass::gemm::GemmCoord *problem_sizes0;       ///< GEMM0 problem sizes; also used to build `reduction`
    cutlass::gemm::GemmCoord *problem_sizes0_real;  ///< forwarded to GemmGrouped0::Arguments
    cutlass::gemm::GemmCoord *problem_sizes1;       ///< GEMM1 problem sizes
    int problem_count;                              ///< number of problems in each grouped GEMM
    int threadblock_count;                          ///< forwarded to both grouped GEMMs

    ElementQ ** ptr_Q;                              ///< GEMM0 A operand (per-problem pointers)
    ElementK ** ptr_K;                              ///< GEMM0 B operand
    ElementP ** ptr_P;                              ///< GEMM0 output; GEMM1 A operand
    // NOTE(review): V's element type is ElementV (= ElementK), but this member
    // is declared as ElementP**; the two only agree when ElementK == ElementP.
    ElementP ** ptr_V;                              ///< GEMM1 B operand
    ElementP ** ptr_O;                              ///< GEMM1 output
    ElementNorm **ptr_Max;                          ///< passed to GEMM0's epilogue visitor and GEMM1's mainloop fusion
    ElementSum **ptr_Sum;                           ///< passed to GEMM0's epilogue visitor and GEMM1's mainloop fusion

    ElementP *block_P;                              ///< stored only; not read by this class
    ElementNorm *block_Norm;                        ///< passed to the reduction arguments
    ElementSum *block_Sum;                          ///< passed to the reduction arguments
    int64_t *offset_P;                              ///< stored only; not read by this class
    int64_t *offset_Norm_Device;                    ///< passed to the reduction arguments
    int64_t *offset_Sum_Device;                     ///< passed to the reduction arguments

    typename LayoutQ::Stride::LongIndex *ldq;       ///< per-problem leading dimensions of Q
    typename LayoutK::Stride::LongIndex *ldk;       ///< per-problem leading dimensions of K
    typename LayoutP::Stride::LongIndex *ldp;       ///< per-problem leading dimensions of P
    typename LayoutP::Stride::LongIndex *ldv;       ///< per-problem leading dimensions of V
    typename LayoutP::Stride::LongIndex *ldo;       ///< per-problem leading dimensions of O

    cutlass::gemm::GemmCoord *problem_sizes0_host;  ///< forwarded to GemmGrouped0::Arguments (may be nullptr)
    cutlass::gemm::GemmCoord *problem_sizes1_host;  ///< forwarded to GemmGrouped1::Arguments (may be nullptr)

    ElementAccumulator alpha0;                      ///< GEMM0 epilogue scale
    ElementAccumulator alpha1;                      ///< passed to EpilogueOutputOp1::Params
    ElementAccumulator beta;                        ///< passed to both epilogues

    int head_number;                                ///< head_number * batch_size = reduction grid size
    int batch_size;
    int seq_length;                                 ///< stored only; not read by this class

    typename ApplyFinalReductionDevice::Arguments reduction;  ///< built from problem_sizes0 and block_/offset_ members

    //
    // Methods
    //

    /// Default constructor: every pointer is nullptr and every scalar is zero,
    /// so no member is left indeterminate.
    Arguments():
      problem_sizes0(nullptr),
      problem_sizes0_real(nullptr),
      problem_sizes1(nullptr),
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      block_P(nullptr),
      block_Norm(nullptr),
      block_Sum(nullptr),
      offset_P(nullptr),
      offset_Norm_Device(nullptr),
      offset_Sum_Device(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldp(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      problem_sizes0_host(nullptr),
      problem_sizes1_host(nullptr),
      alpha0(ElementAccumulator(0)),
      alpha1(ElementAccumulator(0)),
      beta(ElementAccumulator(0)),
      head_number(0),
      batch_size(0),
      seq_length(0)
    {

    }

    /// Member-wise constructor. `reduction` is built from problem_sizes0,
    /// block_Norm, block_Sum, offset_Norm_Device and offset_Sum_Device.
    Arguments(
      cutlass::gemm::GemmCoord *problem_sizes0,
      cutlass::gemm::GemmCoord *problem_sizes1,
      int problem_count,
      int threadblock_count,
      ElementQ ** ptr_Q,
      ElementK ** ptr_K,
      ElementP ** ptr_P,
      ElementP ** ptr_V,
      ElementP ** ptr_O,
      ElementNorm **ptr_Max,
      ElementSum **ptr_Sum,
      ElementP *block_P,
      ElementNorm *block_Norm,
      ElementSum *block_Sum,
      int64_t *offset_P,
      int64_t *offset_Norm_Device,
      int64_t *offset_Sum_Device,
      typename LayoutQ::Stride::LongIndex *ldq,
      typename LayoutK::Stride::LongIndex *ldk,
      typename LayoutP::Stride::LongIndex *ldp,
      typename LayoutP::Stride::LongIndex *ldv,
      typename LayoutP::Stride::LongIndex *ldo,
      ElementAccumulator alpha0,
      ElementAccumulator alpha1,
      ElementAccumulator beta,
      int head_number,
      int batch_size,
      int seq_length,
      cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr
    ):
      // Initializers follow member declaration order.
      problem_sizes0(problem_sizes0),
      problem_sizes0_real(problem_sizes0_real),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      threadblock_count(threadblock_count),
      ptr_Q(ptr_Q),
      ptr_K(ptr_K),
      ptr_P(ptr_P),
      ptr_V(ptr_V),
      ptr_O(ptr_O),
      ptr_Max(ptr_Max),
      ptr_Sum(ptr_Sum),
      block_P(block_P),
      block_Norm(block_Norm),
      block_Sum(block_Sum),
      offset_P(offset_P),
      offset_Norm_Device(offset_Norm_Device),
      offset_Sum_Device(offset_Sum_Device),
      ldq(ldq),
      ldk(ldk),
      ldp(ldp),
      ldv(ldv),
      ldo(ldo),
      problem_sizes0_host(problem_sizes0_host),
      problem_sizes1_host(problem_sizes1_host),
      alpha0(alpha0),
      alpha1(alpha1),
      beta(beta),
      head_number(head_number),
      batch_size(batch_size),
      seq_length(seq_length),
      reduction(
        problem_sizes0,
        block_Norm,
        block_Sum,
        offset_Norm_Device,
        offset_Sum_Device
      )
    {

    }
  };

  /// Internal state copied from Arguments; `reduction` holds the parameters
  /// passed to the reduction kernel at launch.
  struct Params {
    cutlass::gemm::GemmCoord *problem_sizes0;
    cutlass::gemm::GemmCoord *problem_sizes0_real;
    cutlass::gemm::GemmCoord *problem_sizes1;
    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;
    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    ElementP *block_P;
    ElementNorm *block_Norm;
    ElementSum *block_Sum;
    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;
    cutlass::gemm::GemmCoord *problem_sizes1_host;

    ElementAccumulator alpha0;
    ElementAccumulator alpha1;
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Params reduction;

    /// Default constructor: every pointer is nullptr and every scalar is zero.
    Params():
      problem_sizes0(nullptr),
      problem_sizes0_real(nullptr),
      problem_sizes1(nullptr),
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      block_P(nullptr),
      block_Norm(nullptr),
      block_Sum(nullptr),
      offset_P(nullptr),
      offset_Norm_Device(nullptr),
      offset_Sum_Device(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldp(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      problem_sizes0_host(nullptr),
      problem_sizes1_host(nullptr),
      alpha0(ElementAccumulator(0)),
      alpha1(ElementAccumulator(0)),
      beta(ElementAccumulator(0)),
      head_number(0),
      batch_size(0),
      seq_length(0)
    {

    }

    /// Copies every field of `args`; `reduction` is converted from the
    /// reduction Arguments to its Params. `workspace` is accepted for interface
    /// compatibility and is not used.
    Params(Arguments const &args, void *workspace = nullptr):
      problem_sizes0(args.problem_sizes0),
      problem_sizes0_real(args.problem_sizes0_real),
      problem_sizes1(args.problem_sizes1),
      problem_count(args.problem_count),
      threadblock_count(args.threadblock_count),
      ptr_Q(args.ptr_Q),
      ptr_K(args.ptr_K),
      ptr_P(args.ptr_P),
      ptr_V(args.ptr_V),
      ptr_O(args.ptr_O),
      ptr_Max(args.ptr_Max),
      ptr_Sum(args.ptr_Sum),
      block_P(args.block_P),
      block_Norm(args.block_Norm),
      block_Sum(args.block_Sum),
      offset_P(args.offset_P),
      offset_Norm_Device(args.offset_Norm_Device),
      offset_Sum_Device(args.offset_Sum_Device),
      ldq(args.ldq),
      ldk(args.ldk),
      ldp(args.ldp),
      ldv(args.ldv),
      ldo(args.ldo),
      problem_sizes0_host(args.problem_sizes0_host),
      problem_sizes1_host(args.problem_sizes1_host),
      alpha0(args.alpha0),
      alpha1(args.alpha1),
      beta(args.beta),
      head_number(args.head_number),
      batch_size(args.batch_size),
      seq_length(args.seq_length),
      reduction(args.reduction)
    {

    }
  };

private:

  /// Threads per block for the softmax reduction kernel launch.
  static int const kReductionThreadsPerBlock = 1024;

  Params params_;
  GemmGrouped0 gemm_grouped0;
  GemmGrouped1 gemm_grouped1;

public:

  /// Ctor
  FusedMultiHeadAttention() {

  }

  /// Stores `args` and initializes both grouped GEMMs.
  ///
  /// \param args        problem description; copied into internal Params
  /// \param workspace0  forwarded to GemmGrouped0::initialize
  /// \param workspace1  forwarded to GemmGrouped1::initialize
  /// \return kSuccess when both GEMMs initialize; otherwise the failing status,
  ///         with GEMM0's failure taking precedence. Both GEMMs are always
  ///         initialized, even if GEMM0 fails.
  Status initialize(Arguments const &args, void *workspace0 = nullptr, void *workspace1 = nullptr) {

    params_ = Params(args);

    typename GemmGrouped0::Arguments args_gemm0(
      params_.problem_sizes0,
      params_.problem_count,
      params_.threadblock_count,
      params_.ptr_Q,
      params_.ptr_K,
      params_.ptr_P,
      params_.ptr_P,
      params_.ptr_Max,
      params_.ptr_Sum,
      params_.ldq,
      params_.ldk,
      params_.ldp,
      params_.ldp,
      typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments(
        {
          params_.alpha0,
          params_.beta
        }
      ),
      params_.problem_sizes0_host,
      params_.problem_sizes0_real
    );

    Status result0 = gemm_grouped0.initialize(args_gemm0, workspace0);

    typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta);

    typename GemmGrouped1::Arguments args_gemm1(
      params_.problem_sizes1,
      params_.problem_count,
      params_.threadblock_count,
      epilogue_op1,
      params_.ptr_P,
      params_.ptr_V,
      params_.ptr_O,
      params_.ptr_O,
      reinterpret_cast<void **>(params_.ptr_Max),
      reinterpret_cast<void **>(params_.ptr_Sum),
      params_.ldp,
      params_.ldv,
      params_.ldo,
      params_.ldo,
      params_.problem_sizes1_host
    );

    Status result1 = gemm_grouped1.initialize(args_gemm1, workspace1);

    if (result0 != cutlass::Status::kSuccess) {
      return result0;
    }
    return result1;
  }

  /// Enqueues GEMM0, the softmax reduction kernel and GEMM1 on `stream`.
  ///
  /// Every stage is launched on the same stream. The reduction must observe
  /// GEMM0's output, and GEMM1 must observe the reduction's output. Launching
  /// the GEMMs on the default stream while the reduction runs on `stream`
  /// would break that ordering for non-blocking streams.
  ///
  /// \param stream  CUDA stream to enqueue all work on (nullptr = default stream)
  /// \return kSuccess if all launches succeeded; the Status reported by a
  ///         grouped GEMM's run() if it failed; kErrorInternal if the
  ///         reduction kernel launch reported a CUDA error.
  Status run(cudaStream_t stream = nullptr) {

    Status result = gemm_grouped0.run(stream);
    if (result != cutlass::Status::kSuccess) {
      return result;
    }

    // One threadblock per (head, batch) combination.
    dim3 final_reduction_grid(params_.head_number * params_.batch_size);
    dim3 final_reduction_block(kReductionThreadsPerBlock);

    cutlass::Kernel<ApplyFinalReductionDevice><<<
      final_reduction_grid,
      final_reduction_block,
      sizeof(typename ApplyFinalReductionDevice::SharedStorage),
      stream
    >>>(params_.reduction);

    // Kernel launches do not return an error directly; query the launch status.
    cudaError_t error_info = cudaGetLastError();
    if (error_info != cudaSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    result = gemm_grouped1.run(stream);
    if (result != cutlass::Status::kSuccess) {
      return result;
    }

    return cutlass::Status::kSuccess;
  }

  /// Function call operator; equivalent to run(stream).
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }
};

} // namespace cutlass