cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h
dan_the_3rd 146d314057
Update fMHA kernels (#992)
* Update fMHA kernels

Upstream recent changes to fMHA that we did in xFormers.
Previous version in CUTLASS: facebookresearch/xformers@b6be33a
Updating to: facebookresearch/xformers@55a4798

* minor changes

* make var work

---------

Co-authored-by: danthe3rd <danthe3rd>
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
2023-07-12 22:30:46 -04:00

258 lines
11 KiB
C++

/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
////////////////////////////////////////////////////////////////////////////////
// Some helper functions
////////////////////////////////////////////////////////////////////////////////
// Invokes `func()` inside a scope where the type alias `scalar_t` names the
// CUTLASS element type matching `tensor.scalar_type()`:
//   at::ScalarType::Float    -> float
//   at::ScalarType::Half     -> cutlass::half_t
//   at::ScalarType::BFloat16 -> cutlass::bfloat16_t
// Any other dtype fails through XFORMERS_CHECK. `func` is expanded as
// `func()`, so it is usually a parenthesized lambda written at the call site
// (which then sees `scalar_t`), e.g. DISPATCH_TYPES(query, ([&]() { ... })).
// The dtype is read once, from the `tensor` argument itself; the macro does not
// depend on any variable that happens to be in scope at the expansion site.
#define DISPATCH_TYPES(tensor, func)                                          \
  {                                                                           \
    const auto dispatch_types_scalar_type_ = (tensor).scalar_type();          \
    if (dispatch_types_scalar_type_ == at::ScalarType::Float) {               \
      using scalar_t = float;                                                 \
      func();                                                                 \
    } else if (dispatch_types_scalar_type_ == at::ScalarType::Half) {         \
      using scalar_t = cutlass::half_t;                                       \
      func();                                                                 \
    } else if (dispatch_types_scalar_type_ == at::ScalarType::BFloat16) {     \
      using scalar_t = cutlass::bfloat16_t;                                   \
      func();                                                                 \
    } else {                                                                  \
      XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \
    }                                                                         \
  }
// Lifts the runtime boolean BOOL_V to a compile-time constant: F() is invoked
// in a scope that declares `constexpr bool BOOL_NAME` equal to the value of
// BOOL_V, so F (typically a parenthesized lambda) can use BOOL_NAME as a
// template argument. BOOL_V is evaluated exactly once; F is instantiated for
// both values but only the matching branch runs.
#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \
  {                                         \
    if (!(BOOL_V)) {                        \
      constexpr bool BOOL_NAME = false;     \
      F();                                  \
    } else {                                \
      constexpr bool BOOL_NAME = true;      \
      F();                                  \
    }                                       \
  }
// Invokes `func()` inside a scope where the type alias `ArchTag` names the
// newest CUTLASS architecture tag whose threshold CC reaches:
//   CC >= 80 -> cutlass::arch::Sm80
//   CC >= 75 -> cutlass::arch::Sm75
//   CC >= 70 -> cutlass::arch::Sm70
//   CC >= 50 -> cutlass::arch::Sm50
// Anything below 50 fails through XFORMERS_CHECK. Given these thresholds, CC is
// presumably the compute capability encoded as major * 10 + minor.
// CC is evaluated once and parenthesized, so expression arguments are safe.
#define DISPATCH_ARCHTAG(CC, func)                                      \
  {                                                                     \
    const auto dispatch_archtag_cc_ = (CC);                             \
    if (dispatch_archtag_cc_ >= 80) {                                   \
      using ArchTag = cutlass::arch::Sm80;                              \
      func();                                                           \
    } else if (dispatch_archtag_cc_ >= 75) {                            \
      using ArchTag = cutlass::arch::Sm75;                              \
      func();                                                           \
    } else if (dispatch_archtag_cc_ >= 70) {                            \
      using ArchTag = cutlass::arch::Sm70;                              \
      func();                                                           \
    } else if (dispatch_archtag_cc_ >= 50) {                            \
      using ArchTag = cutlass::arch::Sm50;                              \
      func();                                                           \
    } else {                                                            \
      XFORMERS_CHECK(                                                   \
          false,                                                        \
          "Your device is too old. We require compute capability >= 50"); \
    }                                                                   \
  }
// Validates that TENSOR is a CUDA tensor, dense (not sparse) and contiguous.
// Each condition is reported separately through XFORMERS_CHECK. Every call
// passes an explicit message, because the non-PyTorch definitions of
// XFORMERS_CHECK take exactly two arguments (COND, ERR).
// Expands to several statements: wrap the invocation in braces when it is the
// body of an unbraced if/else.
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                            \
  XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor");     \
  XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
  XFORMERS_CHECK(TENSOR.is_contiguous(), #TENSOR " must be contiguous");
// Validates a tensor of which only the innermost dimension must be contiguous:
// TENSOR has to be a CUDA tensor, dense (not sparse), and have stride 1 in its
// last dimension (strides of the other dimensions are not checked).
// Each condition is reported separately through XFORMERS_CHECK. Note that the
// non-PyTorch XFORMERS_CHECK stringifies its condition (#COND), so any textual
// change to these arguments also changes the emitted diagnostics.
// Expands to several statements: wrap the invocation in braces when it is the
// body of an unbraced if/else.
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \
XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \
XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \
XFORMERS_CHECK( \
TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous");
// Error-reporting macros. Which flavour is defined depends on the build:
//  - TORCH_CHECK defined (PyTorch extension): XFORMERS_CHECK is an alias of
//    TORCH_CHECK, and CHECK_ALIGNED_PTR is built on top of it.
//  - __CUDACC_RTC__ (NVRTC): a failed check executes `return false;` in the
//    enclosing function without printing anything.
//  - otherwise: a failed check prints a message to std::cerr and executes
//    `return false;` in the enclosing function.
// In the last two flavours the macros can only be used in functions where
// `return false;` is valid (typically functions returning bool), and
// XFORMERS_CHECK takes exactly two arguments (COND, ERR).
// CHECK_ALIGNED_PTR(PTR, ALIGNMENT) checks that the address held by PTR is a
// multiple of ALIGNMENT. ALIGNMENT is parenthesized so that expression
// arguments such as `kAlign * 2` are fully evaluated before the `%`.
#ifdef TORCH_CHECK
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  XFORMERS_CHECK(                         \
      uint64_t(PTR) % (ALIGNMENT) == 0, #PTR " is not correctly aligned")
#define XFORMERS_CHECK TORCH_CHECK
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT)    \
  if (!(uint64_t(PTR) % (ALIGNMENT) == 0)) { \
    return false;                            \
  }
#define XFORMERS_CHECK(COND, ERR) \
  if (!(COND)) {                  \
    return false;                 \
  }
#else
#include <iostream>
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT)            \
  if (!(uint64_t(PTR) % (ALIGNMENT) == 0)) {         \
    std::cerr << #PTR " is not correctly aligned\n"; \
    return false;                                    \
  }
#define XFORMERS_CHECK(COND, ERR)                       \
  if (!(COND)) {                                        \
    std::cerr << "'" #COND "' failed: " << ERR << "\n"; \
    return false;                                       \
  }
#endif
// Assigns B to A, then fails through XFORMERS_CHECK when the value of B is
// larger than the maximum of A's declared type (i.e. the assignment narrowed).
// B is evaluated exactly once, so arguments with side effects or a non-trivial
// cost are safe. Exactly `max()` fits and is accepted. Only the upper bound is
// checked; values below the minimum of A's type are not detected.
// decltype(A) must name a non-reference arithmetic type (e.g. a struct member
// such as `p.stride`), and std::numeric_limits must be visible where the
// macro is expanded.
#define ASSIGN_CHECK_OVERFLOW(A, B)                     \
  {                                                     \
    const auto assign_check_overflow_value_ = (B);      \
    A = assign_check_overflow_value_;                   \
    XFORMERS_CHECK(                                     \
        assign_check_overflow_value_ <=                 \
            std::numeric_limits<decltype(A)>::max(),    \
        #B " overflows");                               \
  }
namespace gemm_kernel_utils {

/// Integer division rounding up: returns (n + m - 1) / m.
/// Intended for n >= 0 and m > 0. No checks are performed, and `n + m - 1`
/// must not overflow `integer`.
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
  return (n + m - 1) / m;
}

/// Rounds n up to the next multiple of m (same preconditions as ceil_div).
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) {
  return ceil_div(n, m) * m;
}

////////////////////////////////////////////////////////////////////////////////
// Determine the type of GEMM we do (TensorCores or not, Shapes ...)
// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
////////////////////////////////////////////////////////////////////////////////

// Fallback to Simt (FMA on cuda cores) if not in a special case below.
// Selected whenever no specialization below matches (ArchTag, scalar_t_).
template <typename ArchTag, typename scalar_t_, typename Enable = void>
struct DefaultGemmType {
  static constexpr int ThreadK = 8;
  static constexpr int WarpK = 8;
  static constexpr int kMinimumAlignment = 1;
  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
  using OpClass = cutlass::arch::OpClassSimt;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f32: enabled for architectures with
// kMinComputeCapability >= 80, using OpMultiplyAddFastF32 and a 16x8x8
// instruction shape.
template <typename ArchTag>
struct DefaultGemmType<
    ArchTag,
    float,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 80>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAddFastF32;
};

// Specialization for tensorcores with f16/bf16 - Sm75+: enabled for any
// 16-bit element type on architectures with kMinComputeCapability >= 75.
template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
    ArchTag,
    scalar_t,
    typename cutlass::platform::enable_if<
        ArchTag::kMinComputeCapability >= 75 &&
        cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 4;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Specialization for tensorcores with f16 - Volta (Sm70 + cutlass::half_t
// only), using an 8x8x4 instruction shape.
template <>
struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t, void> {
  static constexpr int ThreadK = 32;
  static constexpr int WarpK = 32;
  static constexpr int kMinimumAlignment = 2;
  using OpClass = cutlass::arch::OpClassTensorOp;
  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
  using Operator = cutlass::arch::OpMultiplyAdd;
};

// Enables to do
// `auto x = kCondition ? fa(arg) : fb(arg)`
// when `fa` and `fb` have different types.
// call_conditional<kVal, TA, TB>::apply(ta, tb, arg) calls ta(arg) when kVal is
// true and tb(arg) otherwise; the return type is deduced only from the callable
// that is actually selected, so the two may return unrelated types.
template <bool kVal, typename TA, typename TB>
struct call_conditional;

template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(ta(arg)) {
    return ta(arg);
  }
};

template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
  template <typename Arg>
  static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
      -> decltype(tb(arg)) {
    return tb(arg);
  }
};

////////////////////////////////////////////////////////////////////////////////
// Mark a variable as warp-uniform - enables some compiler optimizations
// The cheapest way to do it is just to broadcast it from lane 0
////////////////////////////////////////////////////////////////////////////////

/// Returns lane 0's `value` to every lane of the warp. The shuffle uses the
/// full mask 0xffffffff, so every lane of the warp is expected to reach the
/// call. The object representation is broadcast in 32-bit words: types of at
/// most 4 bytes take a single shuffle, wider types take one shuffle per word so
/// that no part of the value is left lane-local.
template <typename T>
CUTLASS_DEVICE T warp_uniform(T value) {
  // Number of 32-bit words needed to cover T's object representation.
  constexpr int kWords =
      int((sizeof(T) + sizeof(uint32_t) - 1) / sizeof(uint32_t));
  struct {
    union {
      T value;
      uint32_t asInt[kWords];
    };
  } p;
  p.value = value;
  for (int i = 0; i < kWords; ++i) {
    p.asInt[i] = __shfl_sync(0xffffffff, (unsigned)p.asInt[i], 0);
  }
  return p.value;
}

/// Pointer overload: broadcasts lane 0's pointer as two 32-bit words, which
/// covers a 64-bit pointer.
template <typename T>
CUTLASS_DEVICE T* warp_uniform(T* ptr) {
  struct {
    union {
      T* ptr;
      uint32_t asInt[2];
    };
  } p;
  p.ptr = ptr;
  p.asInt[0] = warp_uniform(p.asInt[0]);
  p.asInt[1] = warp_uniform(p.asInt[1]);
  return p.ptr;
}

} // namespace gemm_kernel_utils