Refactor to template on __half, implement bf16 util functions
This commit is contained in:
parent 2dc1b205f6
commit e518a4b327
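
Every hunk below follows the same pattern: helpers that were hard-coded for fp16 (float2_to_half2, float4_to_half4, hrelu2, hmulsum8, ...) become templates on an element type with __half and __nv_bfloat16 specializations, and the kernels name the element type explicitly, defaulting to __half. A minimal self-contained sketch of that dispatch pattern follows; pack2 and pack_kernel are illustrative names, not part of this commit.

#include <cstdint>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

// The primary template is only declared; each element type supplies its own
// specialization, mirroring fmha::float2_pack in the diff below.
template<typename T>
inline __device__ uint32_t pack2(float a, float b);

template<>
inline __device__ uint32_t pack2<__half>(float a, float b) {
    __half2 h = __floats2half2_rn(a, b);             // round-to-nearest fp16x2
    return reinterpret_cast<uint32_t &>(h);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t pack2<__nv_bfloat16>(float a, float b) {
    __nv_bfloat162 h = __floats2bfloat162_rn(a, b);  // round-to-nearest bf16x2
    return reinterpret_cast<uint32_t &>(h);
}
#endif

// Callers carry the element type as a compile-time parameter that defaults to
// fp16, which is exactly how the kernels below thread __half through
// store/pack/hrelu_.
template<typename elem_type = __half>
__global__ void pack_kernel(const float2 *in, uint32_t *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        out[i] = pack2<elem_type>(in[i].x, in[i].y);
    }
}
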
@@ -142,10 +142,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
}
}

+ template <typename elem_type>
inline __device__ void hrelu_() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
- this->reg(ii) = fmha::hrelu2(this->reg(ii));
+ this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
}
}
};

@@ -27,6 +27,8 @@

#pragma once

+ #include <cuda_fp16.h>
+
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -219,6 +221,7 @@ struct Gmem_tile_o {
}

// Store data to global memory.
+ template<typename elem_type=__half>
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
int row_ = tidx_ / THREADS_PER_ROW;
#pragma unroll
@@ -237,7 +240,7 @@ struct Gmem_tile_o {
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
- uint2 out = float4_to_half4(x, y, z, w);
+ uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
}
@@ -245,7 +248,7 @@ struct Gmem_tile_o {
}
}

- // Store data to global memory.
+ // Load data from global memory.
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
static_assert(BYTES_PER_ELEMENT == 4);
int row_ = tidx_ / THREADS_PER_ROW;
@@ -366,36 +369,6 @@ struct Gmem_tile_mma_s : public Base {
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}

- // Store to global memory.
- template<typename Mask>
- inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
- #pragma unroll
- for( int mi = 0; mi < M; mi++ ) {
- #pragma unroll
- for( int ni = 0; ni < N; ni++ ) {
-
- float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
- float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
- float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
- float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
-
- float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
- float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
- float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
- float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
-
- uint4 dst;
- dst.x = fmha::float2_to_half2(tmp00, tmp01);
- dst.y = fmha::float2_to_half2(tmp02, tmp03);
- dst.z = fmha::float2_to_half2(tmp10, tmp11);
- dst.w = fmha::float2_to_half2(tmp12, tmp13);
- if( mask.is_valid(mi, ni, 0, 0) ) {
- Base::store(dst, mi, ni);
- }
- }
- }
- }
-
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){

@@ -1384,7 +1384,7 @@ struct Smem_tile_mma_epilogue : public Base {
}
}

- template<int M, int N>
+ template<typename elem_type=__half, int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
@@ -1401,10 +1401,10 @@ struct Smem_tile_mma_epilogue : public Base {
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);

- uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
- uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
- uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
- uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
+ uint32_t x = fmha::float2_pack<elem_type>(tmp00, tmp01);
+ uint32_t y = fmha::float2_pack<elem_type>(tmp02, tmp03);
+ uint32_t z = fmha::float2_pack<elem_type>(tmp10, tmp11);
+ uint32_t w = fmha::float2_pack<elem_type>(tmp12, tmp13);

// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);

@@ -34,24 +34,6 @@ namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

- struct Sum_ {
- static constexpr bool IS_SUM = true;
- static inline __device__ float apply(float x, float y) {
- return x + y;
- }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
- struct Max_ {
- static constexpr bool IS_SUM = false;
- static inline __device__ float apply(float x, float y) {
- return x > y ? x : y;
- }
- };
-
- ////////////////////////////////////////////////////////////////////////////////////////////////////
-
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
@@ -508,7 +490,7 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}

// Pack the data to a fragment for the next GEMM.
- template<int K, int M>
+ template<typename elem_type=__half, int K, int M>
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
@@ -528,10 +510,10 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];

// Pack to 4 registers.
- dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
- dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
- dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
- dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
+ dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
+ dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
+ dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
+ dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
}
}
}

@@ -33,6 +33,10 @@

#include <cuda_fp16.h>

+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ #include <cuda_bf16.h>
+ #endif
+
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -310,12 +314,16 @@ static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

- static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
+ template<typename T>
+ inline __device__ uint32_t hrelu2(uint32_t x);
+
+ template<>
+ inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
uint32_t res;
- #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
- asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
- #else
+ const uint32_t zero = 0u;
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+ #else
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
@@ -325,6 +333,19 @@ static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
#endif
return res;
}

+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ template<>
+ inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
+ uint32_t res;
+ const uint32_t zero = 0u;
+ asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
+ return res;
+ }
+ #endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
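
As a usage note (not part of the commit): the templated hrelu2 composes with float2_pack, added later in this diff. With elem_type = __half it lowers to max.f16x2, and with __nv_bfloat16 on SM80+ to max.bf16x2. An illustrative helper, assuming the fmha utils header is included:

// Illustrative only: clamp a packed pair of 16-bit values at zero.
template<typename elem_type>
inline __device__ uint32_t relu_pair(float a, float b) {
    uint32_t packed = fmha::float2_pack<elem_type>(a, b);
    return fmha::hrelu2<elem_type>(packed);
}
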
@@ -332,7 +353,7 @@ static inline __device__ uint32_t habs2(uint32_t x) {
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//

template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
return x < lb ? lb : (x > ub ? ub : x);
@@ -370,6 +391,25 @@ static inline __device__ uint32_t float2_to_half2(float a, float b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

+ template<typename T>
+ inline __device__ uint32_t float2_pack(float a, float b);
+
+ template <>
+ inline __device__ uint32_t float2_pack<__half>(float a, float b) {
+ __half2 result = __floats2half2_rn(a, b);
+ return reinterpret_cast<uint32_t(&)>(result);
+ }
+
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ template <>
+ inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
+ __nv_bfloat162 result = __floats2bfloat162_rn(a, b);
+ return reinterpret_cast<uint32_t(&)>(result);
+ }
+ #endif
+
////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t float_to_half2(float a) {
return float2_to_half2(a,a);
}
@@ -391,6 +431,16 @@ static inline __device__ uint2 float4_to_half4(float x, float y, float z, float

////////////////////////////////////////////////////////////////////////////////////////////////////

+ template<typename T>
+ inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
+ uint2 d;
+ d.x = float2_pack<T>(x, y);
+ d.y = float2_pack<T>(z, w);
+ return d;
+ }
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
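
A hedged usage sketch for the two pack helpers just added (store_row and its arguments are illustrative, not part of the commit); this is the same packing Gmem_tile_o::store now performs through float4_pack<elem_type>:

// Illustrative only: pack four fp32 values into 8 bytes of fp16 or bf16 and
// store them, mirroring what Gmem_tile_o::store writes through fmha::stg.
template<typename elem_type = __half>
inline __device__ void store_row(void *dst, float4 acc) {
    uint2 out = fmha::float4_pack<elem_type>(acc.x, acc.y, acc.z, acc.w);
    *reinterpret_cast<uint2 *>(dst) = out;  // one 8-byte store, four packed elements
}
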
@@ -404,7 +454,7 @@ static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c)
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
- d = hrelu2(hfma2(a, b, c));
+ d = hrelu2<__half>(hfma2(a, b, c));
#endif
return d;
}
@@ -481,32 +531,41 @@ static inline __device__ uint4 hadd8(uint4 a, uint4 b) {

////////////////////////////////////////////////////////////////////////////////////////////////////

- // Converted two half2's into float, then take their dot product.
- // inline __device__ void hfma2_to_float(float &sum, const __half2 a, const __half2 b) {
- static inline __device__ float hfma2_to_float(const __half2 a, const __half2 b) {
- float2 af = __half22float2(a);
- float2 bf = __half22float2(b);
+ template<typename T>
+ inline __device__ float2 half2_unpack(uint32_t a);
+
+ template <>
+ inline __device__ float2 half2_unpack<__half>(uint32_t a) {
+ return __half22float2(reinterpret_cast<__half2 (&)>(a));
+ }
+
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+ template <>
+ inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
+ return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
+ }
+ #endif
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ // Converted two half2's or bf162's into float, then take their dot product.
+ template <typename T>
+ inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
+ float2 af = fmha::half2_unpack<T>(a);
+ float2 bf = fmha::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
// sum += af.x * bf.x + af.y * bf.y;
// sum = __fmaf_rn(sum, af.x, bf.x);
// sum = __fmaf_rn(sum, af.y, bf.y);
// float2 prod = __half22float2(__hmul2(a, b));
// sum += prod.x + prod.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

- // Converted two vectors of 8 half's into float, then take their dot product.
- static inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
+ // Converted two vectors of 8 half's or bf16's into float, then take their dot product.
+ template<typename T>
+ inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
- sum = fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.x),
- reinterpret_cast<const __half2&>(b.x));
- sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.y),
- reinterpret_cast<const __half2&>(b.y));
- sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.z),
- reinterpret_cast<const __half2&>(b.z));
- sum += fmha::hfma2_to_float(reinterpret_cast<const __half2&>(a.w),
- reinterpret_cast<const __half2&>(b.w));
+ sum = fmha::hfma2_to_float<T>(a.x, b.x);
+ sum += fmha::hfma2_to_float<T>(a.y, b.y);
+ sum += fmha::hfma2_to_float<T>(a.z, b.z);
+ sum += fmha::hfma2_to_float<T>(a.w, b.w);
return sum;
}

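
The unpack helpers are what make hmulsum8 element-type agnostic: each uint32_t holds two 16-bit values that are widened to fp32 before the multiply-accumulate. A small round-trip check, illustrative only and assuming the fmha utils header is included:

// Illustrative only: pack, then reduce through the new helpers. For __half the
// round trip matches fp16 rounding; for __nv_bfloat16 it matches bf16 rounding.
template<typename elem_type>
inline __device__ float dot2(float a0, float a1, float b0, float b1) {
    uint32_t a = fmha::float2_pack<elem_type>(a0, a1);
    uint32_t b = fmha::float2_pack<elem_type>(b0, b1);
    // Equivalent to unpacking both operands and computing a0*b0 + a1*b1 in fp32.
    return fmha::hfma2_to_float<elem_type>(a, b);
}
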
@@ -18,7 +18,7 @@ inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const ui
Smem_dp_sum smem, const int buffer_idx) {
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
- sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
+ sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi]));
}
static_assert(M == 1);
smem.store(sum[0], buffer_idx);
@@ -358,7 +358,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
- softmax.pack(frag_p);
+ softmax.template pack<__half>(frag_p);

// Store s * dmask to smem for transpose
smem_s.store(frag_p);
@@ -463,7 +463,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
if (is_first_read) { softmax.subtract_dp_sum(dp_sum); }

Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
- softmax.pack(frag_dp);
+ softmax.template pack<__half>(frag_dp);

if (!Is_dropout) {
#pragma unroll
@@ -544,7 +544,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
- frag_s[ki][mi].hrelu_();
+ frag_s[ki][mi].template hrelu_<__half>();
}
}
}
@@ -638,7 +638,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// }
dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f);
// Output the values.
- gmem_dq.store(dq_out, 0);
+ gmem_dq.template store<__half>(dq_out, 0);
} else {
// Output the values.
gmem_dq_tmp.store(dq_out, 0);
@@ -693,11 +693,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
- smem_dv.store(acc_dv);
+ smem_dv.template store<__half>(acc_dv);

// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
- smem_dk.store(acc_dk);
+ smem_dk.template store<__half>(acc_dk);

__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];

@@ -335,7 +335,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
- softmax.pack(frag_p);
+ softmax.template pack<__half>(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
if (not_last_iter) {
@@ -353,7 +353,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
- frag_p[ki][mi].hrelu_();
+ frag_p[ki][mi].template hrelu_<__half>();
}
}
}
@@ -371,7 +371,6 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c
// The mapping from tidx to rows changes between the softmax and the O-reduction.
// So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
- // TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
@@ -467,7 +466,7 @@ inline __device__ void device_block_1xN_(const Params &params, const int bidb, c

// Output the values.
if (is_final_write) {
- gmem_o.store(out, 0);
+ gmem_o.template store<__half>(out, 0);
} else {
gmem_o_tmp.store(out, 0);
}

@@ -12,14 +12,16 @@ namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

- template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
+ template <int ROWS, int THREADS_PER_ROW, typename elem_type=__half, int M, typename Gmem_softmax_sum>
inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
Gmem_softmax_sum gmem_softmax_d, int tidx) {
float sum[M];
fmha::SumOp<float> sum_op;
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
- sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op) * scale;
+ sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(
+ fmha::hmulsum8<elem_type>(do_[mi], o[mi]), sum_op
+ ) * scale;
}
const int dp_sum_row = tidx / THREADS_PER_ROW;
if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
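
The kernels in this commit still instantiate everything with __half, as in the hunk below, but the new elem_type parameter is what a later bf16 build would hook into. A hypothetical bf16 call site (not made by this commit) would differ only in the element-type argument:

// Hypothetical bf16 instantiation (illustrative): only the element type changes,
// the reduction logic is shared.
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __nv_bfloat16>(
    gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
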
@@ -212,7 +214,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
if (Is_first) {
- dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+ dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
@@ -331,7 +333,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
- softmax.pack(frag_p);
+ softmax.template pack<__half>(frag_p);

// Store s * dmask to smem for transpose
smem_s.store(frag_p);
@@ -422,7 +424,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
}
}

- softmax.pack(frag_p);
+ softmax.template pack<__half>(frag_p);

// Store dp to smem for transpose
smem_dp.store(frag_p);
@@ -473,7 +475,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
- frag_s[ki][mi].hrelu_();
+ frag_s[ki][mi].template hrelu_<__half>();
}
}
}
@@ -517,7 +519,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
if(l < steps - 1) {
gmem_do.commit(smem_do);
if (Is_first) {
- dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+ dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, __half>(
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
);
}
@@ -573,7 +575,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
}
// Output the values.
- gmem_dq.store(dq_out, 0);
+ gmem_dq.template store<__half>(dq_out, 0);
// Move to the next part of the output.
gmem_dq.move();
} else {
@@ -627,11 +629,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// the total amount of shared mem?
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[0], tidx);
- smem_dv.store(acc_dv);
+ smem_dv.template store<__half>(acc_dv);

// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
- smem_dk.store(acc_dk);
+ smem_dk.template store<__half>(acc_dk);

__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
@@ -644,9 +646,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng

uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
- // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
- // dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
- // }
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts, binfo, tidx, false);
if (!Is_first) {
gmem_dk.move(loop_step_idx);
@@ -694,4 +693,4 @@ inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

- } // namespace fmha
+ } // namespace fmha

@@ -466,7 +466,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
- softmax.pack(frag_p);
+ softmax.template pack<__half>(frag_p);
if (Return_softmax) {
gmem_s.store(frag_p, mask);
gmem_s.move();
@@ -482,7 +482,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
#pragma unroll
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
- frag_p[ki][mi].hrelu_();
+ frag_p[ki][mi].template hrelu_<__half>();
}
}
}
@@ -509,7 +509,6 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i
// The mapping from tidx to rows changes between the softmax and the
// O-reduction. So we recalculate the max.
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
- // TODO: not sure if this is right for seqlen 128 or 256
int rows[Gmem_tile_o::STGS_PER_LOOP];
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
@@ -606,7 +605,7 @@ inline __device__ void device_1xN_(const Params &params, const int bidb, const i

// Output the values.
if (is_final_write) {
- gmem_o.store(out, 0);
+ gmem_o.template store<__half>(out, 0);
gmem_o.move();
} else {
gmem_o_tmp.store(out, 0);
