cutlass/examples/41_multi_head_attention/gemm_attention.h
ANIKET SHIVAM b72cbf957d
CUTLASS 2.10 (#615)
Co-authored-by: Aniket Shivam <ashivam@nvidia.com>
2022-09-03 18:48:46 -04:00

627 lines
17 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Defines the FusedMultiHeadAttention Class
The class contains the following:
1) GEMM0 with epilogue fusion,
2) GEMM1 with mainloop fusion, and
3) A lightweight full softmax reduction kernel.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include <cmath>
#include <iostream>
#include <vector>
#include <limits>
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/gemm/kernel/default_gemm_grouped_softmax_mainloop_fusion.h"
#include "cutlass/reduction/kernel/reduce_softmax_final.h"
#include "gemm_grouped_with_softmax_visitor.h"
namespace cutlass {

/// Fused multi-head attention composed of two grouped GEMMs and a softmax reduction.
///
/// run() issues, in order and on the caller's stream:
///   1) GEMM0 (grouped): Q x K -> P, epilogue scaled by alpha0 only
///      (ScaleType::OnlyAlphaScaling). Its epilogue visitor (EpilogueVisitorSoftmax)
///      additionally writes to ptr_Max / ptr_Sum.
///   2) ApplySoftmaxFinalReduction over block_Norm / block_Sum, using
///      offset_Norm_Device / offset_Sum_Device and problem_sizes0.
///   3) GEMM1 (grouped, softmax mainloop fusion): P x V -> O, epilogue scaled by
///      alpha1 with no source term (ScaleType::Nothing); its mainloop reads
///      ptr_Max / ptr_Sum.
///
/// NOTE(review): the `Alignment` and `ElementSoftmax_` template parameters are not
/// referenced anywhere in this class; MapArguments below hard-codes an alignment of 8
/// elements for Q and K. Confirm whether `Alignment` was meant to be used there.
template <
  typename ElementQ_,
  typename LayoutQ_,
  typename ElementK_,
  typename LayoutK_,
  typename ElementP_,
  typename LayoutP_,
  typename ElementCompute_,
  typename OperatorClass_,
  typename ArchTag_,
  typename ThreadblockShape0_,
  typename ThreadblockShape1_,
  typename WarpShape0_,
  typename WarpShape1_,
  typename InstructionShape_,
  int kStages0_,
  int kStages1_,
  bool UseMasking_ = false,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode0_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode1_ = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute,
  int Alignment = 128 / cutlass::sizeof_bits<ElementQ_>::value,
  typename ElementSoftmax_ = ElementP_
>
class FusedMultiHeadAttention {
public:

  //
  // Element types and layouts
  //

  using ElementQ = ElementQ_;
  using ElementK = ElementK_;
  using ElementP = ElementP_;
  using ElementV = ElementK;           // V shares K's element type
  using ElementOutput = ElementP;      // O shares P's element type
  using ElementAccumulator = ElementCompute_;

  using LayoutQ = LayoutQ_;
  using LayoutK = LayoutK_;
  using LayoutP = LayoutP_;
  using LayoutV = LayoutK;             // V shares K's layout
  using LayoutO = LayoutP;             // O shares P's layout

  // Storage types of the softmax statistics are fixed to half precision;
  // the statistics themselves are computed in float.
  using ElementNorm = cutlass::half_t;
  using ElementSum = cutlass::half_t;
  using ElementSoftmaxCompute = float;
  using LayoutNorm = cutlass::layout::RowMajor;

  //
  // Kernel configuration
  //

  using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle;

  using OperatorClass = OperatorClass_;
  using ArchTag = ArchTag_;

  using ThreadblockShape0 = ThreadblockShape0_;
  using WarpShape0 = WarpShape0_;

  using ThreadblockShape1 = ThreadblockShape1_;
  using WarpShape1 = WarpShape1_;

  static int const Stages0 = kStages0_;
  static int const Stages1 = kStages1_;

  using InstructionShape = InstructionShape_;

  /// GEMM0 epilogue: output = alpha * accumulator (no source/beta term).
  using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>;

  /// GEMM1 epilogue: ScaleType::Nothing (no alpha/beta scaling applied).
  using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementAccumulator, cutlass::epilogue::thread::ScaleType::Nothing>;

  using Operator = typename cutlass::gemm::device::DefaultGemmConfiguration<
    OperatorClass, ArchTag, ElementQ, ElementK, ElementP,
    ElementAccumulator>::Operator;

  /// A column-major P is handled by transposing the problem internally.
  static bool const kInternalTranspose = cutlass::platform::is_same<LayoutP, cutlass::layout::ColumnMajor>::value;

  static bool const kUseMasking = UseMasking_;

  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode0 = GroupScheduleMode0_;
  static cutlass::gemm::kernel::GroupScheduleMode const kGroupScheduleMode1 = GroupScheduleMode1_;

  // NOTE(review): alignments for Q and K are hard-coded to 8 elements here rather
  // than taken from the `Alignment` template parameter.
  using MapArguments = cutlass::gemm::kernel::detail::MapArguments<
    ElementQ,
    LayoutQ,
    cutlass::ComplexTransform::kNone,
    8,
    ElementK,
    LayoutK,
    cutlass::ComplexTransform::kNone,
    8,
    LayoutP,
    kInternalTranspose
  >;

  /// Base (non-grouped) GEMM0 kernel; only its Mma and Epilogue components are reused.
  using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm<
    typename MapArguments::ElementA,
    typename MapArguments::LayoutA,
    MapArguments::kAlignmentA,
    typename MapArguments::ElementB,
    typename MapArguments::LayoutB,
    MapArguments::kAlignmentB,
    ElementP,
    typename MapArguments::LayoutC,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape0,
    WarpShape0,
    InstructionShape,
    EpilogueOutputOp0,
    ThreadblockSwizzle,
    Stages0,
    true,
    Operator,
    cutlass::gemm::SharedMemoryClearOption::kNone
  >::GemmKernel;

  /// Epilogue visitor that writes P together with the softmax statistics.
  using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax<
    ThreadblockShape0,
    DefaultGemmKernel::kThreadCount,
    typename DefaultGemmKernel::Epilogue::OutputTileIterator,
    typename EpilogueOutputOp0::ElementCompute,
    ElementNorm,
    ElementSum,
    ElementSoftmaxCompute,
    EpilogueOutputOp0,
    kUseMasking
  >;

  using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue<
    EpilogueVisitor,
    typename DefaultGemmKernel::Epilogue
  >::Epilogue;

  using GemmKernel0 = cutlass::gemm::kernel::GemmGroupedWithEpilogueVistor<
    typename DefaultGemmKernel::Mma,
    Epilogue,
    ThreadblockSwizzle,
    kGroupScheduleMode0,
    kInternalTranspose,
    kUseMasking
  >;

  using GemmGrouped0 = cutlass::gemm::device::GemmGrouped<GemmKernel0>;

  /// Kernel run between GEMM0 and GEMM1 on the block_Norm / block_Sum buffers.
  using ApplyFinalReductionDevice = cutlass::reduction::kernel::ApplySoftmaxFinalReduction<
    ElementNorm,
    ElementSum,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ElementSoftmaxCompute,
    typename GemmGrouped0::GemmKernel::EpilogueVisitor::ThreadblockShape,
    true
  >;

  /// GEMM1: P x V with the softmax statistics consumed in the mainloop.
  using GemmKernel1 = typename cutlass::gemm::kernel::DefaultGemmGroupedSoftmaxMainloopFusion<
    ElementP,
    LayoutP,
    cutlass::ComplexTransform::kNone,
    128 / cutlass::sizeof_bits<ElementQ>::value,
    ElementV,
    LayoutV,
    cutlass::ComplexTransform::kNone,
    128 / cutlass::sizeof_bits<ElementK>::value,
    ElementNorm,
    LayoutNorm,
    ElementOutput,
    LayoutO,
    ElementAccumulator,
    OperatorClass,
    ArchTag,
    ThreadblockShape1,
    WarpShape1,
    InstructionShape,
    EpilogueOutputOp1,
    ThreadblockSwizzle,
    Stages1,
    kGroupScheduleMode1
  >::GemmKernel;

  using GemmGrouped1 = cutlass::gemm::device::GemmGrouped<GemmKernel1>;

public:

  /// Host-side arguments. Pointer-to-pointer members are arrays indexed by problem
  /// (length problem_count); presumably device-resident since they are forwarded
  /// directly to the grouped GEMM kernels — TODO confirm with caller.
  struct Arguments {

    cutlass::gemm::GemmCoord *problem_sizes0;       ///< GEMM0 problem sizes
    cutlass::gemm::GemmCoord *problem_sizes0_real;  ///< forwarded to GEMM0 (optional)
    cutlass::gemm::GemmCoord *problem_sizes1;       ///< GEMM1 problem sizes

    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;

    ElementNorm **ptr_Max;                          ///< written by GEMM0, read by GEMM1
    ElementSum **ptr_Sum;                           ///< written by GEMM0, read by GEMM1

    ElementP *block_P;
    ElementNorm *block_Norm;                        ///< passed to the final reduction
    ElementSum *block_Sum;                          ///< passed to the final reduction

    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;  ///< optional host copy for GEMM0
    cutlass::gemm::GemmCoord *problem_sizes1_host;  ///< optional host copy for GEMM1

    ElementAccumulator alpha0;                      ///< GEMM0 epilogue scale
    ElementAccumulator alpha1;                      ///< GEMM1 epilogue scale
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Arguments reduction;

    //
    // Methods
    //

    /// Empty arguments: every pointer is null and every scalar is zero.
    Arguments():
      problem_sizes0(nullptr),
      problem_sizes0_real(nullptr),
      problem_sizes1(nullptr),
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      block_P(nullptr),
      block_Norm(nullptr),
      block_Sum(nullptr),
      offset_P(nullptr),
      offset_Norm_Device(nullptr),
      offset_Sum_Device(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldp(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      problem_sizes0_host(nullptr),
      problem_sizes1_host(nullptr),
      alpha0(ElementAccumulator(0)),
      alpha1(ElementAccumulator(0)),
      beta(ElementAccumulator(0)),
      head_number(0),
      batch_size(0),
      seq_length(0)
    {

    }

    /// Fully specified arguments. The reduction arguments are derived from
    /// problem_sizes0, block_Norm, block_Sum and the two device offset arrays.
    Arguments(
      cutlass::gemm::GemmCoord *problem_sizes0,
      cutlass::gemm::GemmCoord *problem_sizes1,
      int problem_count,
      int threadblock_count,
      ElementQ ** ptr_Q,
      ElementK ** ptr_K,
      ElementP ** ptr_P,
      ElementP ** ptr_V,
      ElementP ** ptr_O,
      ElementNorm **ptr_Max,
      ElementSum **ptr_Sum,
      ElementP *block_P,
      ElementNorm *block_Norm,
      ElementSum *block_Sum,
      int64_t *offset_P,
      int64_t *offset_Norm_Device,
      int64_t *offset_Sum_Device,
      typename LayoutQ::Stride::LongIndex *ldq,
      typename LayoutK::Stride::LongIndex *ldk,
      typename LayoutP::Stride::LongIndex *ldp,
      typename LayoutP::Stride::LongIndex *ldv,
      typename LayoutP::Stride::LongIndex *ldo,
      ElementAccumulator alpha0,
      ElementAccumulator alpha1,
      ElementAccumulator beta,
      int head_number,
      int batch_size,
      int seq_length,
      cutlass::gemm::GemmCoord *problem_sizes0_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes1_host = nullptr,
      cutlass::gemm::GemmCoord *problem_sizes0_real = nullptr
    ):
      // Initializers follow member declaration order. Each one reads only
      // constructor parameters, so the order carries no hidden dependency.
      problem_sizes0(problem_sizes0),
      problem_sizes0_real(problem_sizes0_real),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      threadblock_count(threadblock_count),
      ptr_Q(ptr_Q),
      ptr_K(ptr_K),
      ptr_P(ptr_P),
      ptr_V(ptr_V),
      ptr_O(ptr_O),
      ptr_Max(ptr_Max),
      ptr_Sum(ptr_Sum),
      block_P(block_P),
      block_Norm(block_Norm),
      block_Sum(block_Sum),
      offset_P(offset_P),
      offset_Norm_Device(offset_Norm_Device),
      offset_Sum_Device(offset_Sum_Device),
      ldq(ldq),
      ldk(ldk),
      ldp(ldp),
      ldv(ldv),
      ldo(ldo),
      problem_sizes0_host(problem_sizes0_host),
      problem_sizes1_host(problem_sizes1_host),
      alpha0(alpha0),
      alpha1(alpha1),
      beta(beta),
      head_number(head_number),
      batch_size(batch_size),
      seq_length(seq_length),
      reduction(
        problem_sizes0,
        block_Norm,
        block_Sum,
        offset_Norm_Device,
        offset_Sum_Device
      )
    {

    }
  };

  /// Internal copy of Arguments, with the reduction arguments converted to kernel Params.
  struct Params {

    cutlass::gemm::GemmCoord *problem_sizes0;
    cutlass::gemm::GemmCoord *problem_sizes0_real;
    cutlass::gemm::GemmCoord *problem_sizes1;

    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementP ** ptr_V;
    ElementP ** ptr_O;

    ElementNorm **ptr_Max;
    ElementSum **ptr_Sum;

    ElementP *block_P;
    ElementNorm *block_Norm;
    ElementSum *block_Sum;

    int64_t *offset_P;
    int64_t *offset_Norm_Device;
    int64_t *offset_Sum_Device;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutP::Stride::LongIndex *ldp;
    typename LayoutP::Stride::LongIndex *ldv;
    typename LayoutP::Stride::LongIndex *ldo;

    cutlass::gemm::GemmCoord *problem_sizes0_host;
    cutlass::gemm::GemmCoord *problem_sizes1_host;

    ElementAccumulator alpha0;
    ElementAccumulator alpha1;
    ElementAccumulator beta;

    int head_number;
    int batch_size;
    int seq_length;

    typename ApplyFinalReductionDevice::Params reduction;

    /// Empty params: every pointer is null and every scalar is zero.
    Params():
      problem_sizes0(nullptr),
      problem_sizes0_real(nullptr),
      problem_sizes1(nullptr),
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_Max(nullptr),
      ptr_Sum(nullptr),
      block_P(nullptr),
      block_Norm(nullptr),
      block_Sum(nullptr),
      offset_P(nullptr),
      offset_Norm_Device(nullptr),
      offset_Sum_Device(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldp(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      problem_sizes0_host(nullptr),
      problem_sizes1_host(nullptr),
      alpha0(ElementAccumulator(0)),
      alpha1(ElementAccumulator(0)),
      beta(ElementAccumulator(0)),
      head_number(0),
      batch_size(0),
      seq_length(0)
    {

    }

    /// Copies every field from `args`. The `workspace` parameter is currently unused.
    Params(Arguments const &args, void *workspace = nullptr):
      problem_sizes0(args.problem_sizes0),
      problem_sizes0_real(args.problem_sizes0_real),
      problem_sizes1(args.problem_sizes1),
      problem_count(args.problem_count),
      threadblock_count(args.threadblock_count),
      ptr_Q(args.ptr_Q),
      ptr_K(args.ptr_K),
      ptr_P(args.ptr_P),
      ptr_V(args.ptr_V),
      ptr_O(args.ptr_O),
      ptr_Max(args.ptr_Max),
      ptr_Sum(args.ptr_Sum),
      block_P(args.block_P),
      block_Norm(args.block_Norm),
      block_Sum(args.block_Sum),
      offset_P(args.offset_P),
      offset_Norm_Device(args.offset_Norm_Device),
      offset_Sum_Device(args.offset_Sum_Device),
      ldq(args.ldq),
      ldk(args.ldk),
      ldp(args.ldp),
      ldv(args.ldv),
      ldo(args.ldo),
      problem_sizes0_host(args.problem_sizes0_host),
      problem_sizes1_host(args.problem_sizes1_host),
      alpha0(args.alpha0),
      alpha1(args.alpha1),
      beta(args.beta),
      head_number(args.head_number),
      batch_size(args.batch_size),
      seq_length(args.seq_length),
      reduction(args.reduction)
    {

    }
  };

private:

  Params params_;
  GemmGrouped0 gemm_grouped0;
  GemmGrouped1 gemm_grouped1;

public:

  /// Ctor
  FusedMultiHeadAttention() {

  }

  /// Initializes both grouped GEMMs from `args`.
  ///
  /// \param args        problem description
  /// \param workspace0  workspace forwarded to GEMM0's initialize()
  /// \param workspace1  workspace forwarded to GEMM1's initialize()
  /// \return kSuccess, or the first non-success status returned by GEMM0 then GEMM1.
  ///         GEMM1 is not initialized when GEMM0 fails.
  Status initialize(Arguments const &args,
                    void *workspace0 = nullptr,
                    void *workspace1 = nullptr) {

    params_ = Params(args);

    // GEMM0: Q x K -> P. ptr_P is passed as both source (C) and destination (D).
    // The epilogue visitor additionally receives ptr_Max / ptr_Sum.
    typename GemmGrouped0::Arguments args_gemm0(
      params_.problem_sizes0,
      params_.problem_count,
      params_.threadblock_count,
      params_.ptr_Q,
      params_.ptr_K,
      params_.ptr_P,
      params_.ptr_P,
      params_.ptr_Max,
      params_.ptr_Sum,
      params_.ldq,
      params_.ldk,
      params_.ldp,
      params_.ldp,
      typename GemmGrouped0::GemmKernel::EpilogueVisitor::Arguments(
        {
          params_.alpha0,
          params_.beta
        }
      ),
      params_.problem_sizes0_host,
      params_.problem_sizes0_real
    );

    Status result = gemm_grouped0.initialize(args_gemm0, workspace0);
    if (result != cutlass::Status::kSuccess) {
      return result;
    }

    // GEMM1: P x V -> O. ptr_O is passed as both source and destination.
    // ptr_Max / ptr_Sum are consumed by the fused mainloop.
    typename EpilogueOutputOp1::Params epilogue_op1(params_.alpha1, params_.beta);

    typename GemmGrouped1::Arguments args_gemm1(
      params_.problem_sizes1,
      params_.problem_count,
      params_.threadblock_count,
      epilogue_op1,
      params_.ptr_P,
      params_.ptr_V,
      params_.ptr_O,
      params_.ptr_O,
      (void**)params_.ptr_Max,
      (void**)params_.ptr_Sum,
      params_.ldp,
      params_.ldv,
      params_.ldo,
      params_.ldo,
      params_.problem_sizes1_host
    );

    return gemm_grouped1.initialize(args_gemm1, workspace1);
  }

  /// Runs GEMM0, the final softmax reduction and GEMM1, in that order, all on `stream`.
  ///
  /// The three launches communicate only through device memory: P, ptr_Max / ptr_Sum,
  /// and block_Norm / block_Sum. They must therefore be enqueued on the same stream so
  /// that stream ordering serializes them. Enqueuing the GEMMs on the default stream
  /// would race with a non-blocking `stream`.
  ///
  /// The reduction is launched with one threadblock per (head, batch) pair
  /// (head_number * batch_size blocks) and 1024 threads per block.
  ///
  /// \return kSuccess, the failing status of a grouped GEMM, or kErrorInternal if the
  ///         reduction launch reports a CUDA error.
  Status run(cudaStream_t stream = nullptr) {

    Status result = gemm_grouped0.run(stream);
    if (result != cutlass::Status::kSuccess) {
      return result;
    }

    int const kReductionThreadsPerBlock = 1024;

    dim3 final_reduction_grid(params_.head_number * params_.batch_size);
    dim3 final_reduction_block(kReductionThreadsPerBlock);

    cutlass::Kernel<ApplyFinalReductionDevice><<<
      final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionDevice::SharedStorage), stream
    >>>(params_.reduction);

    // This catches launch-configuration errors only; faults raised inside the kernel
    // surface at the next synchronizing call.
    cudaError_t error_info = cudaGetLastError();
    if (error_info != cudaSuccess) {
      return cutlass::Status::kErrorInternal;
    }

    return gemm_grouped1.run(stream);
  }

  /// Function call operator; equivalent to run(stream).
  Status operator()(cudaStream_t stream = nullptr) {
    return run(stream);
  }
};

}