diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h
index a082e67..a142f0b 100644
--- a/csrc/flash_attn/src/fmha/gemm.h
+++ b/csrc/flash_attn/src/fmha/gemm.h
@@ -253,9 +253,26 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
     }
 }
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+/// Statically maps half types => cutlass data types
+////////////////////////////////////////////////////////////////////////////////////////////////////
+template <typename Type_>
+struct HalfTypeToCutlassType { using Type = Type_; };
+
+/// Statically maps __half => cutlass::half_t
+template <> struct HalfTypeToCutlassType<__half> {
+    using Type = cutlass::half_t;
+};
+
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+template <> struct HalfTypeToCutlassType<__nv_bfloat16> {
+    using Type = cutlass::bfloat16_t;
+};
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Acc, typename A, typename B, int M, int N>
+template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
 inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
     using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
@@ -267,7 +284,7 @@ inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N
     // TD [2022-06-02] We don't support Volta (SM70) yet.
     assert(0);
 #endif
-    using Element = cutlass::half_t;
+    using Element = typename HalfTypeToCutlassType<elem_type>::Type;
     using ElementC = float;
     using LayoutA = cutlass::layout::RowMajor;
     using LayoutB = cutlass::layout::ColumnMajor;
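(For context, outside the diff: `HalfTypeToCutlassType` is an ordinary trait, a compile-time type map with one specialization per half type, so `gemm_cl` can stay written against the raw CUDA types while cutlass sees its own element types. Below is a minimal, host-compilable sketch of the same pattern; `fake_half` and `fake_cutlass_half` are stand-ins invented for this sketch, since `__half` and `cutlass::half_t` only exist under nvcc.)

```cpp
#include <type_traits>

struct fake_half {};          // stand-in for __half (assumption of this sketch)
struct fake_cutlass_half {};  // stand-in for cutlass::half_t

// Primary template: identity map, mirroring HalfTypeToCutlassType above.
template <typename Type_>
struct HalfTypeToCutlassType { using Type = Type_; };

// Specialization: fake_half => fake_cutlass_half.
template <>
struct HalfTypeToCutlassType<fake_half> { using Type = fake_cutlass_half; };

int main() {
    // Both mappings resolve entirely at compile time.
    static_assert(std::is_same<HalfTypeToCutlassType<fake_half>::Type,
                               fake_cutlass_half>::value, "specialized case");
    static_assert(std::is_same<HalfTypeToCutlassType<float>::Type,
                               float>::value, "identity fallback");
    return 0;
}
```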
diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
index 51a2f92..afa2fd1 100644
--- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
@@ -407,9 +407,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
         smem_do.load(frag_do[ki & 1], ki);
         if (!Kernel_traits::V_IN_REGS) {
             smem_v.load(frag_v[ki & 1], ki);
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
         } else {
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
         }
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
         //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
@@ -423,9 +423,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
     {
         int ki = Mma_tile_p::MMAS_K;
         if (!Kernel_traits::V_IN_REGS) {
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
         } else {
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
         }
     }
 
@@ -514,14 +514,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_kt.load(frag_kt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
 
         // Do the final stage of math.
         {
             int ki = Mma_tile_dq::MMAS_K;
-            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         static_assert(Gmem_tile_dq::LOOPS == 1);
@@ -554,13 +554,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_dot.load(frag_dot[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
 
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
 
         // __syncthreads();
@@ -612,13 +612,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_qt.load(frag_qt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
 
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
 
         // Make sure dQ is in shared memory.
diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
index d2c2d01..3cec1c6 100644
--- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
@@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
         // Do this part of O = P^T * V^T.
         #pragma unroll
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-            fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
+            fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]);
         }
 
         // The mapping from tidx to rows changes between the softmax and the O-reduction.
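(The call-site edit repeated through these kernels is mechanical: the element type becomes the one explicit template argument, while the tile extents and fragment types are still deduced from the array-reference arguments. Below is a hedged, CPU-only analogue of that signature; the toy body and the `short` stand-ins are assumptions of this sketch, whereas the real `gemm_cl` dispatches to a cutlass tensor-core MMA.)

```cpp
#include <cstdio>

// elem_type leads the parameter list so it can be supplied explicitly while
// Acc/A/B/M/N are deduced from the arguments, as in gemm_cl after this diff.
template <typename elem_type, typename Acc, typename A, typename B, int M, int N>
void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
    // Toy stand-in for the MMA: acc[m][n] += a[m] * b[n].
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            acc[m][n] += static_cast<Acc>(a[m]) * static_cast<Acc>(b[n]);
}

int main() {
    float acc[2][3] = {};
    short a[2] = {1, 2};        // shorts standing in for __half fragments
    short b[3] = {3, 4, 5};
    gemm_cl<short>(acc, a, b);  // only elem_type is explicit, the rest is deduced
    std::printf("acc[1][2] = %g\n", acc[1][2]);  // prints 10 (2 * 5)
    return 0;
}
```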
diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
index 00d1681..bf65b9a 100644
--- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
@@ -369,9 +369,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         smem_do.load(frag_do[ki & 1], ki);
         if (!Kernel_traits::V_IN_REGS) {
             smem_v.load(frag_v[ki & 1], ki);
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
         } else {
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
         }
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
         //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
@@ -385,9 +385,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     {
         int ki = Mma_tile_p::MMAS_K;
         if (!Kernel_traits::V_IN_REGS) {
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
         } else {
-            fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
         }
     }
 
@@ -442,14 +442,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Trigger the load from shared memory for the next series of Q values.
             smem_kt.load(frag_kt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
 
         // Do the final stage of math.
         {
             int ki = Mma_tile_dq::MMAS_K;
-            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
-            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         static_assert(Gmem_tile_dq::LOOPS == 1);
@@ -485,13 +485,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Trigger the load from shared memory for the next series of Q values.
             smem_dot.load(frag_dot[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
 
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
 
         // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
@@ -542,13 +542,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
             // Trigger the load from shared memory for the next series of Q values.
             smem_qt.load(frag_qt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
 
         // Do the final stage of math.
         {
             int ki = Mma_tile_dkv::MMAS_K;
-            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
         }
 
         // Make sure dQ is in shared memory.
diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
index 2c00888..21b3328 100644
--- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
@@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base {
             // Trigger the load from shared memory for the next series of Q values.
             Base::smem_q.load(Base::frag_q[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
+            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
         }
     }
 
@@ -175,12 +175,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base {
             Base::smem_q.load(Base::frag_q[ki & 1], ki);
             Base::smem_k.load(frag_k[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_p::MMAS_K;
-            fmha::gemm_cl(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
+            fmha::gemm_cl<__half>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
         }
     }
 
@@ -494,7 +494,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
         // Do this part of O = P^T * V^T.
         #pragma unroll
         for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
-            fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]);
+            fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]);
             // if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
             //     float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
             //     float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
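(Taken together: the explicit `<__half>` at each call site flows into `HalfTypeToCutlassType` inside `gemm_cl`, so `Element` resolves to `cutlass::half_t` today and would resolve to `cutlass::bfloat16_t` once SM80 call sites pass `__nv_bfloat16`. A host-side sketch of that resolution; the stubbed CUDA/cutlass types and the `Element` alias are assumptions of this sketch.)

```cpp
#include <type_traits>

struct fake_half {};   // stand-in for __half
struct fake_bf16 {};   // stand-in for __nv_bfloat16
namespace cutlass { struct half_t {}; struct bfloat16_t {}; }  // stubs

template <typename Type_> struct HalfTypeToCutlassType { using Type = Type_; };
template <> struct HalfTypeToCutlassType<fake_half> { using Type = cutlass::half_t; };
template <> struct HalfTypeToCutlassType<fake_bf16> { using Type = cutlass::bfloat16_t; };

// Mirrors the one changed line in gemm_cl: the mma element type is derived
// from elem_type instead of being hardcoded to cutlass::half_t.
template <typename elem_type>
using Element = typename HalfTypeToCutlassType<elem_type>::Type;

static_assert(std::is_same<Element<fake_half>, cutlass::half_t>::value, "fp16 path");
static_assert(std::is_same<Element<fake_bf16>, cutlass::bfloat16_t>::value, "bf16 path");

int main() { return 0; }
```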