diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h index 22ef11a..a082e67 100644 --- a/csrc/flash_attn/src/fmha/gemm.h +++ b/csrc/flash_attn/src/fmha/gemm.h @@ -142,10 +142,11 @@ struct alignas(static_cast(Base_::ALIGNMENT)) Fragment : public Base_ { } } + template inline __device__ void hrelu_() { #pragma unroll for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) { - this->reg(ii) = fmha::hrelu2(this->reg(ii)); + this->reg(ii) = fmha::hrelu2(this->reg(ii)); } } }; diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 932d2d6..e903c33 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -27,6 +27,8 @@ #pragma once +#include + namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -219,6 +221,7 @@ struct Gmem_tile_o { } // Store data to global memory. + template inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { int row_ = tidx_ / THREADS_PER_ROW; #pragma unroll @@ -237,7 +240,7 @@ struct Gmem_tile_o { float y = reinterpret_cast(src[ii].y); float z = reinterpret_cast(src[ii].z); float w = reinterpret_cast(src[ii].w); - uint2 out = float4_to_half4(x, y, z, w); + uint2 out = fmha::float4_pack(x, y, z, w); if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out); } @@ -245,7 +248,7 @@ struct Gmem_tile_o { } } - // Store data to global memory. + // Load data from global memory. inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) { static_assert(BYTES_PER_ELEMENT == 4); int row_ = tidx_ / THREADS_PER_ROW; @@ -366,36 +369,6 @@ struct Gmem_tile_mma_s : public Base { : Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) { } - // Store to global memory. - template - inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) { - #pragma unroll - for( int mi = 0; mi < M; mi++ ) { - #pragma unroll - for( int ni = 0; ni < N; ni++ ) { - - float tmp00 = softmax[2 * mi + 0][4 * ni + 0]; - float tmp01 = softmax[2 * mi + 0][4 * ni + 1]; - float tmp02 = softmax[2 * mi + 0][4 * ni + 2]; - float tmp03 = softmax[2 * mi + 0][4 * ni + 3]; - - float tmp10 = softmax[2 * mi + 1][4 * ni + 0]; - float tmp11 = softmax[2 * mi + 1][4 * ni + 1]; - float tmp12 = softmax[2 * mi + 1][4 * ni + 2]; - float tmp13 = softmax[2 * mi + 1][4 * ni + 3]; - - uint4 dst; - dst.x = fmha::float2_to_half2(tmp00, tmp01); - dst.y = fmha::float2_to_half2(tmp02, tmp03); - dst.z = fmha::float2_to_half2(tmp10, tmp11); - dst.w = fmha::float2_to_half2(tmp12, tmp13); - if( mask.is_valid(mi, ni, 0, 0) ) { - Base::store(dst, mi, ni); - } - } - } - } - // Store to global memory. template inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){ diff --git a/csrc/flash_attn/src/fmha/smem_tile.h b/csrc/flash_attn/src/fmha/smem_tile.h index 18579e5..4be3809 100644 --- a/csrc/flash_attn/src/fmha/smem_tile.h +++ b/csrc/flash_attn/src/fmha/smem_tile.h @@ -1384,7 +1384,7 @@ struct Smem_tile_mma_epilogue : public Base { } } - template + template inline __device__ void store(const Acc (&acc)[M][N]){ #pragma unroll for( int mi = 0; mi < M; mi++ ) { @@ -1401,10 +1401,10 @@ struct Smem_tile_mma_epilogue : public Base { float tmp12 = acc[mi][ni].elt(6); float tmp13 = acc[mi][ni].elt(7); - uint32_t x = fmha::float2_to_half2(tmp00, tmp01); - uint32_t y = fmha::float2_to_half2(tmp02, tmp03); - uint32_t z = fmha::float2_to_half2(tmp10, tmp11); - uint32_t w = fmha::float2_to_half2(tmp12, tmp13); + uint32_t x = fmha::float2_pack(tmp00, tmp01); + uint32_t y = fmha::float2_pack(tmp02, tmp03); + uint32_t z = fmha::float2_pack(tmp10, tmp11); + uint32_t w = fmha::float2_pack(tmp12, tmp13); // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x); diff --git a/csrc/flash_attn/src/fmha/softmax.h b/csrc/flash_attn/src/fmha/softmax.h index 2de6761..c4783ee 100644 --- a/csrc/flash_attn/src/fmha/softmax.h +++ b/csrc/flash_attn/src/fmha/softmax.h @@ -34,24 +34,6 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -struct Sum_ { - static constexpr bool IS_SUM = true; - static inline __device__ float apply(float x, float y) { - return x + y; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Max_ { - static constexpr bool IS_SUM = false; - static inline __device__ float apply(float x, float y) { - return x > y ? x : y; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - inline __device__ float apply_exp_(float x, float max) { return __expf(x - max); } @@ -508,7 +490,7 @@ struct Softmax : public Softmax_base { } // Pack the data to a fragment for the next GEMM. - template + template inline __device__ void pack(Fragment_a (&dst)[K][M]) const { #pragma unroll for( int mi = 0; mi < M; ++mi ) { @@ -528,10 +510,10 @@ struct Softmax : public Softmax_base { float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3]; // Pack to 4 registers. - dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01); - dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11); - dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03); - dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13); + dst[ki][mi].reg(0) = fmha::float2_pack(tmp_00, tmp_01); + dst[ki][mi].reg(1) = fmha::float2_pack(tmp_10, tmp_11); + dst[ki][mi].reg(2) = fmha::float2_pack(tmp_02, tmp_03); + dst[ki][mi].reg(3) = fmha::float2_pack(tmp_12, tmp_13); } } } diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h index 1087f80..ecb8aef 100644 --- a/csrc/flash_attn/src/fmha/utils.h +++ b/csrc/flash_attn/src/fmha/utils.h @@ -33,6 +33,10 @@ #include +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -310,12 +314,16 @@ static inline __device__ uint4 hmul8(uint32_t a, uint4 b) { //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { +template +inline __device__ uint32_t hrelu2(uint32_t x); + +template<> +inline __device__ uint32_t hrelu2<__half>(uint32_t x) { uint32_t res; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb)); -#else const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else asm volatile( \ "{\n" \ "\t .reg .f16x2 sela;\n" \ @@ -325,6 +333,19 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) { #endif return res; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static inline __device__ uint32_t habs2(uint32_t x) { uint32_t res; asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x)); @@ -332,7 +353,7 @@ static inline __device__ uint32_t habs2(uint32_t x) { } //////////////////////////////////////////////////////////////////////////////////////////////////// -// + template< typename T > static inline __device__ T clamp(T x, T lb, T ub) { return x < lb ? lb : (x > ub ? ub : x); @@ -370,6 +391,25 @@ static inline __device__ uint32_t float2_to_half2(float a, float b) { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ uint32_t float2_pack(float a, float b); + +template <> +inline __device__ uint32_t float2_pack<__half>(float a, float b) { + __half2 result = __floats2half2_rn(a, b); + return reinterpret_cast(result); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) { + __nv_bfloat162 result = __floats2bfloat162_rn(a, b); + return reinterpret_cast(result); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static inline __device__ uint32_t float_to_half2(float a) { return float2_to_half2(a,a); } @@ -391,6 +431,16 @@ static inline __device__ uint2 float4_to_half4(float x, float y, float z, float //////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline __device__ uint2 float4_pack(float x, float y, float z, float w) { + uint2 d; + d.x = float2_pack(x, y); + d.y = float2_pack(z, w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) { uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); @@ -404,7 +454,7 @@ static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c)); #else - d = hrelu2(hfma2(a, b, c)); + d = hrelu2<__half>(hfma2(a, b, c)); #endif return d; } @@ -481,32 +531,41 @@ static inline __device__ uint4 hadd8(uint4 a, uint4 b) { //////////////////////////////////////////////////////////////////////////////////////////////////// -// Converted two half2's into float, then take their dot product. -// inline __device__ void hfma2_to_float(float &sum, const __half2 a, const __half2 b) { -static inline __device__ float hfma2_to_float(const __half2 a, const __half2 b) { - float2 af = __half22float2(a); - float2 bf = __half22float2(b); +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2 (&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = fmha::half2_unpack(a); + float2 bf = fmha::half2_unpack(b); return af.x * bf.x + af.y * bf.y; - // sum += af.x * bf.x + af.y * bf.y; - // sum = __fmaf_rn(sum, af.x, bf.x); - // sum = __fmaf_rn(sum, af.y, bf.y); - // float2 prod = __half22float2(__hmul2(a, b)); - // sum += prod.x + prod.y; } //////////////////////////////////////////////////////////////////////////////////////////////////// -// Converted two vectors of 8 half's into float, then take their dot product. -static inline __device__ float hmulsum8(const uint4 a, const uint4 b) { +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { float sum; - sum = fmha::hfma2_to_float(reinterpret_cast(a.x), - reinterpret_cast(b.x)); - sum += fmha::hfma2_to_float(reinterpret_cast(a.y), - reinterpret_cast(b.y)); - sum += fmha::hfma2_to_float(reinterpret_cast(a.z), - reinterpret_cast(b.z)); - sum += fmha::hfma2_to_float(reinterpret_cast(a.w), - reinterpret_cast(b.w)); + sum = fmha::hfma2_to_float(a.x, b.x); + sum += fmha::hfma2_to_float(a.y, b.y); + sum += fmha::hfma2_to_float(a.z, b.z); + sum += fmha::hfma2_to_float(a.w, b.w); return sum; } diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index e821e55..51a2f92 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -18,7 +18,7 @@ inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const ui Smem_dp_sum smem, const int buffer_idx) { #pragma unroll for (int mi = 0; mi < M; ++mi) { - sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi])); + sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi])); } static_assert(M == 1); smem.store(sum[0], buffer_idx); @@ -358,7 +358,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); - softmax.pack(frag_p); + softmax.template pack<__half>(frag_p); // Store s * dmask to smem for transpose smem_s.store(frag_p); @@ -463,7 +463,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, if (is_first_read) { softmax.subtract_dp_sum(dp_sum); } Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - softmax.pack(frag_dp); + softmax.template pack<__half>(frag_dp); if (!Is_dropout) { #pragma unroll @@ -544,7 +544,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - frag_s[ki][mi].hrelu_(); + frag_s[ki][mi].template hrelu_<__half>(); } } } @@ -638,7 +638,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // } dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f); // Output the values. - gmem_dq.store(dq_out, 0); + gmem_dq.template store<__half>(dq_out, 0); } else { // Output the values. gmem_dq_tmp.store(dq_out, 0); @@ -693,11 +693,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // the total amount of shared mem? // Epilogue swizzle for dV Smem_tile_dv smem_dv(&smem_[0], tidx); - smem_dv.store(acc_dv); + smem_dv.template store<__half>(acc_dv); // Epilogue swizzle for dK Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); - smem_dk.store(acc_dk); + smem_dk.template store<__half>(acc_dk); __syncthreads(); uint4 dv_out[Smem_tile_dv::NUM_LDS]; diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h index 8323029..d2c2d01 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h @@ -335,7 +335,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); - softmax.pack(frag_p); + softmax.template pack<__half>(frag_p); if (Return_softmax) { gmem_s.store(frag_p, mask); if (not_last_iter) { @@ -353,7 +353,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - frag_p[ki][mi].hrelu_(); + frag_p[ki][mi].template hrelu_<__half>(); } } } @@ -371,7 +371,6 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c // The mapping from tidx to rows changes between the softmax and the O-reduction. // So we recalculate the max. float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - // TODO: not sure if this is right for seqlen 128 or 256 int rows[Gmem_tile_o::STGS_PER_LOOP]; for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; @@ -467,7 +466,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c // Output the values. if (is_final_write) { - gmem_o.store(out, 0); + gmem_o.template store<__half>(out, 0); } else { gmem_o_tmp.store(out, 0); } diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 15a50bb..00d1681 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -12,14 +12,16 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale, Gmem_softmax_sum gmem_softmax_d, int tidx) { float sum[M]; fmha::SumOp sum_op; #pragma unroll for (int mi = 0; mi < M; ++mi) { - sum[mi] = fmha::Allreduce::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op) * scale; + sum[mi] = fmha::Allreduce::run( + fmha::hmulsum8(do_[mi], o[mi]), sum_op + ) * scale; } const int dp_sum_row = tidx / THREADS_PER_ROW; if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) { @@ -212,7 +214,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng gmem_q.commit(gemm_q_k.smem_q); gmem_do.commit(smem_do); if (Is_first) { - dot_do_o( + dot_do_o( gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx ); } @@ -331,7 +333,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); - softmax.pack(frag_p); + softmax.template pack<__half>(frag_p); // Store s * dmask to smem for transpose smem_s.store(frag_p); @@ -422,7 +424,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng } } - softmax.pack(frag_p); + softmax.template pack<__half>(frag_p); // Store dp to smem for transpose smem_dp.store(frag_p); @@ -473,7 +475,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) { - frag_s[ki][mi].hrelu_(); + frag_s[ki][mi].template hrelu_<__half>(); } } } @@ -517,7 +519,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng if(l < steps - 1) { gmem_do.commit(smem_do); if (Is_first) { - dot_do_o( + dot_do_o( gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx ); } @@ -573,7 +575,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout); } // Output the values. - gmem_dq.store(dq_out, 0); + gmem_dq.template store<__half>(dq_out, 0); // Move to the next part of the output. gmem_dq.move(); } else { @@ -627,11 +629,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // the total amount of shared mem? // Epilogue swizzle for dV Smem_tile_dv smem_dv(&smem_[0], tidx); - smem_dv.store(acc_dv); + smem_dv.template store<__half>(acc_dv); // Epilogue swizzle for dK Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); - smem_dk.store(acc_dk); + smem_dk.template store<__half>(acc_dk); __syncthreads(); uint4 dv_out[Smem_tile_dv::NUM_LDS]; @@ -644,9 +646,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng uint4 dk_out[Smem_tile_dk::NUM_LDS]; smem_dk.load(dk_out); - // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) { - // dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f); - // } Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false); if (!Is_first) { gmem_dk.move(loop_step_idx); @@ -694,4 +693,4 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace fmha \ No newline at end of file +} // namespace fmha diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index c8dcee8..2c00888 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -466,7 +466,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); - softmax.pack(frag_p); + softmax.template pack<__half>(frag_p); if (Return_softmax) { gmem_s.store(frag_p, mask); gmem_s.move(); @@ -482,7 +482,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - frag_p[ki][mi].hrelu_(); + frag_p[ki][mi].template hrelu_<__half>(); } } } @@ -509,7 +509,6 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // The mapping from tidx to rows changes between the softmax and the // O-reduction. So we recalculate the max. float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M]; - // TODO: not sure if this is right for seqlen 128 or 256 int rows[Gmem_tile_o::STGS_PER_LOOP]; for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) { rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG; @@ -606,7 +605,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i // Output the values. if (is_final_write) { - gmem_o.store(out, 0); + gmem_o.template store<__half>(out, 0); gmem_o.move(); } else { gmem_o_tmp.store(out, 0);