diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index 5cad7f3..9721458 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -29,6 +29,13 @@ #include <fmha/utils.h> +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/layout/layout.h" +#include <cutlass/arch/mma.h> +#include <cutlass/gemm/gemm.h> +#include <cutlass/numeric_types.h> + namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -247,6 +254,49 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) //////////////////////////////////////////////////////////////////////////////////////////////////// +template<typename Acc, typename A, typename B, int M, int N> +inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) { + using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type; + + using FragmentA = typename WarpMma::FragmentA; + using FragmentB = typename WarpMma::FragmentB; + using FragmentC = typename WarpMma::FragmentC; + + static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS); + static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS); + static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS); + const FragmentA a_cl = reinterpret_cast<const FragmentA &>(a); + const FragmentB b_cl = reinterpret_cast<const FragmentB &>(b); + FragmentC c_cl = reinterpret_cast<const FragmentC &>(acc); + + WarpMma mma_op; + mma_op(c_cl, a_cl, b_cl, c_cl); + + // The modified c_cl is not copied back into acc, idk why + #pragma unroll + for (int mi = 0; mi < M; mi++) { + #pragma unroll + for (int ni = 0; ni < N; ni++) { +
#pragma unroll + for (int i =0; i < 8; i++) { + acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i]; + } + } + } + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template< // The number of rows in the CTA tile. int M_, diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index 097eaee..f65b696 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -408,9 +408,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, smem_do.load(frag_do[ki & 1], ki); if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[ki & 1], ki); - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); } // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); @@ -424,9 +424,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, { int ki = Mma_tile_p::MMAS_K; if (!Kernel_traits::V_IN_REGS) { - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); } } @@ -515,14 +515,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, // Trigger the load from shared memory for the next series of Q values. smem_kt.load(frag_kt[ki & 1], ki); // Do the math for the values already in registers.
- fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); } // Do the final stage of math. { int ki = Mma_tile_dq::MMAS_K; - fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); } static_assert(Gmem_tile_dq::LOOPS == 1); @@ -555,13 +555,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, // Trigger the load from shared memory for the next series of Q values. smem_dot.load(frag_dot[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // __syncthreads(); @@ -613,13 +613,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params, // Trigger the load from shared memory for the next series of Q values. smem_qt.load(frag_qt[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Make sure dQ is in shared memory.
diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h index cb0074d..33674f1 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h @@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c // Do this part of O = P^T * V^T. #pragma unroll for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm(acc_o, frag_p[ki], frag_v[ki]); + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); } // The mapping from tidx to rows changes between the softmax and the O-reduction. diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index c50e5e2..45732a1 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -383,9 +383,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng smem_do.load(frag_do[ki & 1], ki); if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[ki & 1], ki); - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); } // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); @@ -399,9 +399,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng { int ki = Mma_tile_p::MMAS_K; if (!Kernel_traits::V_IN_REGS) { - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); } } @@ -484,14 +484,14 @@ inline __device__
void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng // Trigger the load from shared memory for the next series of Q values. smem_kt.load(frag_kt[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); } // Do the final stage of math. { int ki = Mma_tile_dq::MMAS_K; - fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); - // fmha::gemm(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); + fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]); + // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]); } static_assert(Gmem_tile_dq::LOOPS == 1); @@ -524,13 +524,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng // Trigger the load from shared memory for the next series of Q values. smem_dot.load(frag_dot[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // __syncthreads(); @@ -579,13 +579,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng // Trigger the load from shared memory for the next series of Q values. smem_qt.load(frag_qt[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Do the final stage of math.
{ int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Make sure dQ is in shared memory. diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index eed2331..93aef12 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base { // Trigger the load from shared memory for the next series of Q values. Base::smem_q.load(Base::frag_q[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); } // Do the final stage of math. { int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]); } } @@ -175,12 +175,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base { Base::smem_q.load(Base::frag_q[ki & 1], ki); Base::smem_k.load(frag_k[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_p::MMAS_K; - fmha::gemm(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); + fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]); } } @@ -497,7 +497,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i // Do this part of O = P^T * V^T. #pragma unroll for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm(acc_o, frag_p[ki], frag_v[ki]); + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); } // The mapping from tidx to rows changes between the softmax and the O-reduction.