/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <fmha/utils.h>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // The default input type.
    using Input_type_ = Data_type_;
    // Does it store the array of elements?
    static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
    // The number of elements.
    static constexpr int NUM_ELTS = NUM_ELTS_;
    // The size of an element in bits.
    static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
    // The size in bytes of a single register.
    static constexpr int BYTES_PER_REG = 4;
    // The size in bits of a single register.
    static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
    // The number of registers needed to store the fragment.
    static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
    // The size in bytes (as returned by sizeof(Fragment_base<>)).
    static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
    // The alignment.
    static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
};

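// Illustrative compile-time check (not part of the original interface): for 8 fp16
// elements (16 bits each), the fragment needs ceil(8 * 16 / 32) = 4 registers, i.e.
// 16 bytes, and is therefore 16-byte aligned.
static_assert(Fragment_base_<uint16_t, 8, 16, 0>::NUM_REGS == 4, "unexpected register count");
static_assert(Fragment_base_<uint16_t, 8, 16, 0>::SIZE_IN_BYTES == 16, "unexpected size");
static_assert(Fragment_base_<uint16_t, 8, 16, 0>::ALIGNMENT == 16, "unexpected alignment");
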
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements.
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
    }

    // Add another fragment, element by element.
    inline __device__ void add(const Fragment &other) {
        // TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
        // Also, are we doing int addition or __half2 addition?
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }

    // Multiply by another fragment, register by register, as packed __half2 values.
    inline __device__ void hmul(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
        }
    }

    // Apply ReLU to each register, as packed pairs of the given element type.
    template <typename elem_type>
    inline __device__ void hrelu_() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
        }
    }
};

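// Usage sketch for device code (illustrative; "frag" and "other" are hypothetical
// fragments filled elsewhere): data can be viewed either as packed 32-bit registers
// via reg(), or as individual elements via elt()/elt_as<>().
//
//   Fragment<uint16_t, 8> frag;      // 8 fp16 values in 4 registers (16 bytes)
//   frag.clear();                    // zero all registers
//   frag.elt(0) = other.elt(0);      // per-element access
//   frag.hmul(other);                // packed __half2 multiply, register by register
static_assert(Fragment<uint16_t, 8>::BYTES_PER_LOAD_STORE == 16,
              "illustrative check: 8 fp16 elements load/store as one 16-byte access");
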
////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add two fragments, element by element.
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    // Scale all elements by a float.
    inline __device__ void mul_(const float other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) *= other;
        }
    }

    // Do the HMMA: two m16n8k16 instructions covering the full 16x16 output tile.
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a,
                               const Fragment_b<Layout_b> &b) {
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                : "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
                : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
                , "r"(b.reg(0)), "r"(b.reg(1)));
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
                : "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
                : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
                , "r"(b.reg(2)), "r"(b.reg(3)));
    }
};

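// Illustrative check and usage sketch (not part of the original code): the accumulator
// holds 8 fp32 values in 8 registers; one mma() call issues two m16n8k16 HMMAs, so it
// covers a full 16x16x16 step per warp.
static_assert(Fragment_accumulator::NUM_ELTS == 8, "unexpected element count");
static_assert(Fragment_accumulator::NUM_REGS == 8, "unexpected register count");
//
//   Fragment_accumulator acc;        // names below are hypothetical
//   acc.clear();
//   acc.mma(frag_a, frag_b);         // acc += A * B for this 16x16x16 step
//   acc.mul_(scale);                 // e.g. scale the partial sums
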
////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            frag[mi][ni].clear();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
    template< typename Acc, int M, int N >
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {

    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            acc[mi][ni].mma(a[mi], b[ni]);
        }
    }
}

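// Usage sketch (illustrative; the array names and extents are hypothetical and come
// from the surrounding kernel, typically sized by Hmma_tile below): one K-step of the
// GEMM loop is a single call that does an outer product of the A and B fragments.
//
//   Fragment_accumulator acc[MMAS_M][MMAS_N];
//   Fragment_a<LayoutA>  frag_a[MMAS_M];
//   Fragment_b<LayoutB>  frag_b[MMAS_N];
//   fmha::clear(acc);
//   fmha::gemm(acc, frag_a, frag_b);   // acc[mi][ni] += frag_a[mi] * frag_b[ni]
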
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps half types => cutlass data types
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Type_>
struct HalfTypeToCutlassType { using Type = Type_; };

/// Statically maps __half => cutlass::half_t
template <> struct HalfTypeToCutlassType<__half> {
    using Type = cutlass::half_t;
};

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
/// Statically maps __nv_bfloat16 => cutlass::bfloat16_t
template <> struct HalfTypeToCutlassType<__nv_bfloat16> {
    using Type = cutlass::bfloat16_t;
};
#endif

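// Example (illustrative): gemm_cl below uses this mapping to pick the CUTLASS element
// type from the kernel's half type, e.g.
//   using Element = typename HalfTypeToCutlassType<__half>::Type;          // cutlass::half_t
//   using Element = typename HalfTypeToCutlassType<__nv_bfloat16>::Type;   // cutlass::bfloat16_t (SM80+)
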
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    // TD [2022-06-02] We don't support Volta (SM70) yet.
    assert(0);
#endif
    using Element = typename HalfTypeToCutlassType<elem_type>::Type;
    using ElementC = float;
    using LayoutA = cutlass::layout::RowMajor;
    using LayoutB = cutlass::layout::ColumnMajor;

    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;

    // The number of instruction-level K iterations per call.
    constexpr int kIters = Shape::kK / InstructionShape::kK;
    // using FragmentA = typename WarpMma::FragmentA;
    // using FragmentB = typename WarpMma::FragmentB;
    using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
    using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
    using FragmentC = typename WarpMma::FragmentC;

    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
    //     printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
    //     printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
    //     printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
    //     printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
    //     printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
    //     printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
    // }

    // static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
    // static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
    static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
    static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
    // const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
    // const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
    // Note: c_cl is a copy of acc, not a reference, so the results are copied back below.
    FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
    // Repack the fmha fragments into CUTLASS fragments, one set per instruction-level K iteration.
    FragmentA a_cl[kIters][M];
    FragmentA b_cl[kIters][N];
    constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int mi = 0; mi < M; mi++) {
            uint32_t *a_ptr = a_cl[iter][mi].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
            }
        }
    }
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            uint32_t *b_ptr = b_cl[iter][ni].raw_data();
            #pragma unroll
            for (int ki = 0; ki < kRegs; ki++) {
                // b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
                // TD [2022-06-02] For some reason the order for frag_b is different.
                b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
            }
        }
    }

    WarpMma mma_op;
    // mma_op(c_cl, a_cl, b_cl, c_cl);
    #pragma unroll
    for (int iter = 0; iter < kIters; iter++) {
        mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
               reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
    }

    // c_cl was initialized as a copy of acc, so write the accumulated results back into acc.
    #pragma unroll
    for (int mi = 0; mi < M; mi++) {
        #pragma unroll
        for (int ni = 0; ni < N; ni++) {
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
            }
        }
    }
}

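// Worked sizes for the SM80 path (derived from the code above, shown for reference):
// Shape::kK = 16 and InstructionShape = <16, 8, 16>, so kIters = 1 and kRegs = 4; each
// Fragment_a (4 registers) maps to one ArchMmaOperator::FragmentA, and each Fragment_b
// (4 registers) feeds the two n=8 B fragments of a 16-wide column tile. On SM75
// (InstructionShape = <16, 8, 8>) the same math gives kIters = 2 and kRegs = 2.
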
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_>
struct Cta_tile_ {

    static constexpr int M = M_, N = N_, K = K_;
    // The number of warps.
    static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
    // The number of warps per CTA.
    static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
    // The number of threads per warp.
    static constexpr int THREADS_PER_WARP = 32;
    // The number of threads per CTA.
    static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};

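// Illustrative instantiation (not part of the original code): a 2x1x1 warp layout
// gives 2 warps and 64 threads per CTA.
static_assert(Cta_tile_<64, 32, 16, 2, 1, 1>::WARPS_PER_CTA == 2, "unexpected warp count");
static_assert(Cta_tile_<64, 32, 16, 2, 1, 1>::THREADS_PER_CTA == 64, "unexpected thread count");
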
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;

    // The number of elements computed with a single CTA-MMA.
    static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
                         N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
                         K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;

    // The number of MMAs needed to compute the GEMM.
    static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
                         MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
                         MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);

    // // The number of elements computed per warp.
    // static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
    //                      N_PER_WARP = MMAS_N * N_PER_MMA,
    //                      K_PER_WARP = MMAS_K * K_PER_MMA;
};

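// Illustrative instantiation (not part of the original code): for a 64x32x16 CTA tile
// with 2x1x1 warps, each 16x16x16 warp-MMA covers 32 rows and 16 columns per CTA, so
// the tile needs 2 x 2 x 1 MMAs in M x N x K.
static_assert(Hmma_tile<Cta_tile_<64, 32, 16, 2, 1, 1>>::MMAS_M == 2, "unexpected MMAS_M");
static_assert(Hmma_tile<Cta_tile_<64, 32, 16, 2, 1, 1>>::MMAS_N == 2, "unexpected MMAS_N");
static_assert(Hmma_tile<Cta_tile_<64, 32, 16, 2, 1, 1>>::MMAS_K == 1, "unexpected MMAS_K");
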
////////////////////////////////////////////////////////////////////////////////////////////////////

using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha