Refactor to template on __half, implement bf16 util functions

This commit is contained in:
Tri Dao 2022-07-08 15:18:58 -07:00
parent 2dc1b205f6
commit e518a4b327
9 changed files with 128 additions and 116 deletions

View File

@ -142,10 +142,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
}
}
template <typename elem_type>
inline __device__ void hrelu_() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
this->reg(ii) = fmha::hrelu2(this->reg(ii));
this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
}
}
};

View File

@ -27,6 +27,8 @@
#pragma once
#include <cuda_fp16.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -219,6 +221,7 @@ struct Gmem_tile_o {
}
// Store data to global memory.
template<typename elem_type=__half>
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
@ -237,7 +240,7 @@ struct Gmem_tile_o {
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = float4_to_half4(x, y, z, w);
uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
}
@ -245,7 +248,7 @@ struct Gmem_tile_o {
}
}
// Store data to global memory.
// Load data from global memory.
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4);
int row_ = tidx_ / THREADS_PER_ROW;
@ -366,36 +369,6 @@ struct Gmem_tile_mma_s : public Base {
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
template<typename Mask>
inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
uint4 dst;
dst.x = fmha::float2_to_half2(tmp00, tmp01);
dst.y = fmha::float2_to_half2(tmp02, tmp03);
dst.z = fmha::float2_to_half2(tmp10, tmp11);
dst.w = fmha::float2_to_half2(tmp12, tmp13);
if( mask.is_valid(mi, ni, 0, 0) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){

View File

@ -1384,7 +1384,7 @@ struct Smem_tile_mma_epilogue : public Base {
}
}
template<int M, int N>
template<typename elem_type=__half, int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
@ -1401,10 +1401,10 @@ struct Smem_tile_mma_epilogue : public Base {
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
uint32_t x = fmha::float2_pack<elem_type>(tmp00, tmp01);
uint32_t y = fmha::float2_pack<elem_type>(tmp02, tmp03);
uint32_t z = fmha::float2_pack<elem_type>(tmp10, tmp11);
uint32_t w = fmha::float2_pack<elem_type>(tmp12, tmp13);
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);

View File

@ -34,24 +34,6 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Sum_ {
static constexpr bool IS_SUM = true;
static inline __device__ float apply(float x, float y) {
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Max_ {
static constexpr bool IS_SUM = false;
static inline __device__ float apply(float x, float y) {
return x > y ? x : y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
@ -508,7 +490,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
// Pack the data to a fragment for the next GEMM.
template<int K, int M>
template<typename elem_type=__half, int K, int M>
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
@ -528,10 +510,10 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
// Pack to 4 registers.
dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
}
}
}

View File

@ -33,6 +33,10 @@
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -310,12 +314,16 @@ static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
template<typename T>
inline __device__ uint32_t hrelu2(uint32_t x);
template<>
inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
uint32_t res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
@ -325,6 +333,19 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
#endif
return res;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
@ -332,7 +353,7 @@ static inline __device__ uint32_t habs2(uint32_t x) {
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
return x < lb ? lb : (x > ub ? ub : x);
@ -370,6 +391,25 @@ static inline __device__ uint32_t float2_to_half2(float a, float b) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t float2_pack(float a, float b);
template <>
inline __device__ uint32_t float2_pack<__half>(float a, float b) {
__half2 result = __floats2half2_rn(a, b);
return reinterpret_cast<uint32_t(&)>(result);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
__nv_bfloat162 result = __floats2bfloat162_rn(a, b);
return reinterpret_cast<uint32_t(&)>(result);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_half2(float a) {
return float2_to_half2(a,a);
}
@ -391,6 +431,16 @@ static inline __device__ uint2 float4_to_half4(float x, float y, float z, float
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
uint2 d;
d.x = float2_pack<T>(x, y);
d.y = float2_pack<T>(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
@ -404,7 +454,7 @@ static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
d = hrelu2(hfma2(a, b, c));
d = hrelu2<__half>(hfma2(a, b, c));
#endif
return d;
}
@ -481,32 +531,41 @@ static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two half2's into float, then take their dot product.
// inline __device__ void hfma2_to_float(float &sum, const __half2 a, const __half2 b) {
static inline __device__ float hfma2_to_float(const __half2 a, const __half2 b) {
float2 af = __half22float2(a);
float2 bf = __half22float2(b);
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = fmha::half2_unpack<T>(a);
float2 bf = fmha::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
// sum += af.x * bf.x + af.y * bf.y;
// sum = __fmaf_rn(sum, af.x, bf.x);
// sum = __fmaf_rn(sum, af.y, bf.y);
// float2 prod = __half22float2(__hmul2(a, b));
// sum += prod.x + prod.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Converted two vectors of 8 half's into float, then take their dot product.
static inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.x),
reinterpret_cast<const __half2&>(b.x));
sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.y),
reinterpret_cast<const __half2&>(b.y));
sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.z),
reinterpret_cast<const __half2&>(b.z));
sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.w),
reinterpret_cast<const __half2&>(b.w));
sum = fmha::hfma2_to_float<T>(a.x, b.x);
sum += fmha::hfma2_to_float<T>(a.y, b.y);
sum += fmha::hfma2_to_float<T>(a.z, b.z);
sum += fmha::hfma2_to_float<T>(a.w, b.w);
return sum;
}

View File

@ -18,7 +18,7 @@ inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const ui
Smem_dp_sum smem, const int buffer_idx) {
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi]));
}
static_assert(M == 1);
smem.store(sum[0], buffer_idx);
@ -358,7 +358,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
softmax.template pack<__half>(frag_p);
// Store s * dmask to smem for transpose
smem_s.store(frag_p);
@ -463,7 +463,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
if (is_first_read) { softmax.subtract_dp_sum(dp_sum); }
Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
softmax.pack(frag_dp);
softmax.template pack<__half>(frag_dp);
if (!Is_dropout) {
#pragma unroll
@ -544,7 +544,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
frag_s[ki][mi].hrelu_();
frag_s[ki][mi].template hrelu_<__half>();
}
}
}
@ -638,7 +638,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// }
dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
// Output the values.
gmem_dq.store(dq_out, 0);
gmem_dq.template store<__half>(dq_out, 0);
} else {
// Output the values.
gmem_dq_tmp.store(dq_out, 0);
@ -693,11 +693,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
smem_dv.store(acc_dv);
smem_dv.template store<__half>(acc_dv);
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
smem_dk.store(acc_dk);
smem_dk.template store<__half>(acc_dk);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];

View File

@ -335,7 +335,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
softmax.template pack<__half>(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
if (not_last_iter) {
@ -353,7 +353,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
frag_p[ki][mi].hrelu_();
frag_p[ki][mi].template hrelu_<__half>();
}
}
}
@ -371,7 +371,6 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
// The mapping from tidx to rows changes between the softmax and the O-reduction.
// So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
// TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
@ -467,7 +466,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
// Output the values.
if (is_final_write) {
gmem_o.store(out, 0);
gmem_o.template store<__half>(out, 0);
} else {
gmem_o_tmp.store(out, 0);
}

View File

@ -12,14 +12,16 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
template <int ROWS, int THREADS_PER_ROW, typename elem_type=__half, int M, typename Gmem_softmax_sum>
inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
Gmem_softmax_sum gmem_softmax_d, int tidx) {
float sum[M];
fmha::SumOp<float> sum_op;
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op) * scale;
sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(
fmha::hmulsum8<elem_type>(do_[mi], o[mi]), sum_op
) * scale;
}
const int dp_sum_row = tidx / THREADS_PER_ROW;
if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
@ -212,7 +214,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
@ -331,7 +333,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
softmax.template pack<__half>(frag_p);
// Store s * dmask to smem for transpose
smem_s.store(frag_p);
@ -422,7 +424,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
}
}
softmax.pack(frag_p);
softmax.template pack<__half>(frag_p);
// Store dp to smem for transpose
smem_dp.store(frag_p);
@ -473,7 +475,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
frag_s[ki][mi].hrelu_();
frag_s[ki][mi].template hrelu_<__half>();
}
}
}
@ -517,7 +519,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
if(l < steps - 1) {
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
@ -573,7 +575,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
}
// Output the values.
gmem_dq.store(dq_out, 0);
gmem_dq.template store<__half>(dq_out, 0);
// Move to the next part of the output.
gmem_dq.move();
} else {
@ -627,11 +629,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
smem_dv.store(acc_dv);
smem_dv.template store<__half>(acc_dv);
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
smem_dk.store(acc_dk);
smem_dk.template store<__half>(acc_dk);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
@ -644,9 +646,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
// for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
// dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
// }
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
@ -694,4 +693,4 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
} // namespace fmha

View File

@ -466,7 +466,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack(frag_p);
softmax.template pack<__half>(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
gmem_s.move();
@ -482,7 +482,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
frag_p[ki][mi].hrelu_();
frag_p[ki][mi].template hrelu_<__half>();
}
}
}
@ -509,7 +509,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// The mapping from tidx to rows changes between the softmax and the
// O-reduction. So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
// TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
@ -606,7 +605,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// Output the values.
if (is_final_write) {
gmem_o.store(out, 0);
gmem_o.template store<__half>(out, 0);
gmem_o.move();
} else {
gmem_o_tmp.store(out, 0);