Use Cutlass gemm as WarpMma

Tri Dao 2022-06-02 10:33:32 -07:00
parent e78e7c9553
commit 14dc326e59
5 changed files with 80 additions and 30 deletions

View File

@@ -29,6 +29,13 @@
#include <fmha/utils.h>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
#include "cutlass/layout/layout.h"
#include <cutlass/arch/mma.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -247,6 +254,49 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
////////////////////////////////////////////////////////////////////////////////////////////////////
+template<typename Acc, typename A, typename B, int M, int N>
+inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
+    // Warp-level tile of M x N 16x16x16 MMAs, each built from 16x8x16 mma.sync instructions.
+    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
+    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+    using Element = cutlass::half_t;
+    using ElementC = float;
+    using LayoutA = cutlass::layout::RowMajor;
+    using LayoutB = cutlass::layout::ColumnMajor;
+    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
+        Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
+        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
+    using FragmentA = typename WarpMma::FragmentA;
+    using FragmentB = typename WarpMma::FragmentB;
+    using FragmentC = typename WarpMma::FragmentC;
+    // The Cutlass fragments must hold exactly the registers of the fmha fragments.
+    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
+    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
+    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
+    const FragmentA &a_cl = reinterpret_cast<const FragmentA &>(a);
+    const FragmentB &b_cl = reinterpret_cast<const FragmentB &>(b);
+    FragmentC c_cl = reinterpret_cast<FragmentC &>(acc);
+    WarpMma mma_op;
+    mma_op(c_cl, a_cl, b_cl, c_cl);
+    // c_cl is a by-value copy of acc, so the MMA result does not propagate back on
+    // its own: copy it into acc explicitly (8 floats per MMA == acc[0][0].NUM_REGS).
+    #pragma unroll
+    for (int mi = 0; mi < M; mi++) {
+        #pragma unroll
+        for (int ni = 0; ni < N; ni++) {
+            #pragma unroll
+            for (int i = 0; i < 8; i++) {
+                acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
+            }
+        }
+    }
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
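Note: the copy-back loop in gemm_cl exists only because c_cl is a by-value copy of acc. Below is a minimal hypothetical sketch (not part of this commit) of a variant that accumulates in place by reinterpreting acc itself as the Cutlass accumulator fragment; gemm_cl already passes the same fragment as both the C and D operands of the warp MMA, and the layout assumptions are the same ones its static_asserts check. The name gemm_cl_inplace is made up for illustration.

// Hypothetical sketch: let the WarpMma write its accumulators straight into acc,
// so no copy-back loop is needed. Same assumptions as gemm_cl above.
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm_cl_inplace(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
    using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        Shape, InstructionShape, cutlass::half_t, cutlass::layout::RowMajor,
        cutlass::half_t, cutlass::layout::ColumnMajor, float,
        cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
    using FragmentA = typename WarpMma::FragmentA;
    using FragmentB = typename WarpMma::FragmentB;
    using FragmentC = typename WarpMma::FragmentC;
    static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
    static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
    static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
    const FragmentA &a_cl = reinterpret_cast<const FragmentA &>(a);
    const FragmentB &b_cl = reinterpret_cast<const FragmentB &>(b);
    // Bind the accumulator by reference instead of copying it.
    FragmentC &c_cl = reinterpret_cast<FragmentC &>(acc);
    WarpMma mma_op;
    mma_op(c_cl, a_cl, b_cl, c_cl);  // result lands directly in acc
}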
template<
// The number of rows in the CTA tile.
int M_,

View File

@@ -408,9 +408,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
smem_do.load(frag_do[ki & 1], ki);
if (!Kernel_traits::V_IN_REGS) {
smem_v.load(frag_v[ki & 1], ki);
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
@@ -424,9 +424,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
{
int ki = Mma_tile_p::MMAS_K;
if (!Kernel_traits::V_IN_REGS) {
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
}
}
@@ -515,14 +515,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// Trigger the load from shared memory for the next series of Q values.
smem_kt.load(frag_kt[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dq::MMAS_K;
-fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert(Gmem_tile_dq::LOOPS == 1);
@@ -555,13 +555,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// Trigger the load from shared memory for the next series of Q values.
smem_dot.load(frag_dot[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
-fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// __syncthreads();
@@ -613,13 +613,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
-fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Make sure dQ is in shared memory.

View File

@@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
+fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.

View File

@@ -383,9 +383,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
smem_do.load(frag_do[ki & 1], ki);
if (!Kernel_traits::V_IN_REGS) {
smem_v.load(frag_v[ki & 1], ki);
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
@@ -399,9 +399,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
{
int ki = Mma_tile_p::MMAS_K;
if (!Kernel_traits::V_IN_REGS) {
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
} else {
-fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
}
}
@@ -484,14 +484,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_kt.load(frag_kt[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dq::MMAS_K;
-fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-// fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
static_assert(Gmem_tile_dq::LOOPS == 1);
@@ -524,13 +524,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_dot.load(frag_dot[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
-fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
}
// __syncthreads();
@@ -579,13 +579,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dkv::MMAS_K;
-fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Make sure dQ is in shared memory.

View File

@@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
// Trigger the load from shared memory for the next series of Q values.
Base::smem_q.load(Base::frag_q[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
-fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
}
}
@@ -175,12 +175,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
Base::smem_q.load(Base::frag_q[ki & 1], ki);
Base::smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
-fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
-fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
}
@@ -497,7 +497,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Do this part of O = P^T * V^T.
#pragma unroll
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
+fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.