Remove softmax fp16 max

Tri Dao 2022-06-02 14:09:46 -07:00
parent 14dc326e59
commit 050873327e
2 changed files with 5 additions and 125 deletions
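The diff below removes the fp16 (__half2) path for the softmax max reduction and exponentiation, leaving only the float path. That surviving path subtracts the row max before exponentiating and folds the base-e to base-2 conversion into the multiply, as the apply_exp2_ helper and the comment in apply_exp explain. A minimal host-side sketch of that pattern (illustrative only; the function name, shapes, and demo values are assumptions, not code from the kernel):

#include <cmath>
#include <cstdio>

// Numerically stable softmax over one row, written the way the float path does it:
// exp(x - max) == exp2(x * log2(e) - max * log2(e)), so the scale by log2(e) and the
// subtraction of the pre-scaled max can fuse into a single FMA ahead of the exp2
// (and on the device, with fast math, exp2f lowers to a single ex2.approx).
void softmax_row(const float* x, float* y, int n) {
    constexpr float kLog2e = 1.4426950408889634f;  // log_2(e), same constant as M_LOG2E
    float row_max = x[0];
    for (int i = 1; i < n; ++i) row_max = fmaxf(row_max, x[i]);
    const float max_log2e = row_max * kLog2e;      // pre-scale the max once per row
    float sum = 0.f;
    for (int i = 0; i < n; ++i) {
        y[i] = exp2f(x[i] * kLog2e - max_log2e);
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) y[i] /= sum;       // normalize
}

int main() {
    float x[4] = {1.f, 2.f, 3.f, 4.f}, y[4];
    softmax_row(x, y, 4);
    for (int i = 0; i < 4; ++i) printf("%f\n", y[i]);
    return 0;
}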


@@ -58,12 +58,6 @@ inline __device__ float apply_exp_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __half2 apply_exp_(__half2 x, __half2 max) {
return h2exp(__hsub2(x, max));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp2_(float x, float max) {
return exp2f(x - max);
// With fast-math, this produces the same PTX instruction as the assembly below
@@ -75,17 +69,9 @@ inline __device__ float apply_exp2_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ __half2 apply_exp2_(__half2 x, __half2 max) {
return h2exp2(__hsub2(x, max));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS, bool half> struct ReadType {};
template<> struct ReadType<4, false> { using T = float;};
template<> struct ReadType<8, false> { using T = float2;};
template<> struct ReadType<4, true> { using T = __half2;};
template<> struct ReadType<8, true> { using T = float2;};
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -118,8 +104,7 @@ struct Smem_tile_reduce {
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS, /*half=*/false>::T;
using read_half_t = typename ReadType<COLS, /*half=*/true>::T;
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
@@ -152,17 +137,6 @@ struct Smem_tile_reduce {
}
}
__device__ inline void store(__half2 (&frag)[MMAS_M]) {
__half2 *smem_write_half_ = reinterpret_cast<__half2 *>(smem_write_);
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_half_[offset + 0 * 8 * WARPS_N] = frag[mi];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -172,15 +146,6 @@ struct Smem_tile_reduce {
}
}
__device__ inline void load(read_half_t (&frag)[MMAS_M]) {
read_half_t *smem_read_half_ = reinterpret_cast<read_half_t *>(smem_read_);
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi] = smem_read_half_[offset + 0 * 8 * 4];
}
}
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
@@ -304,29 +269,6 @@ struct Softmax_base {
}
}
// Apply the exp to all the elements.
inline __device__ void apply_exp(const __half2 (&max)[MMAS_M]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)). This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
constexpr float kLog2e = M_LOG2E;
const float2 max_f = __half22float2(max[mi]);
const float max0_log2e = max_f.x * kLog2e, max1_log2e = max_f.y * kLog2e;
#pragma unroll
for (int ni = 0; ni < MMAS_N * 4; ++ni) {
float2 elt = __half22float2(elt_half_[mi][ni]);
elt_[mi * 2 + 0][ni] = apply_exp2_(elt.x * kLog2e, max0_log2e);
elt_[mi * 2 + 1][ni] = apply_exp2_(elt.y * kLog2e, max1_log2e);
// __half2 out = apply_exp_(elt_half_[mi][ni], max[mi]);
// float2 outf = __half22float2(out);
// elt_[mi * 2 + 0][ni] = outf.x;
// elt_[mi * 2 + 1][ni] = outf.y;
}
}
}
// Apply the exp to all the elements.
template <bool max_in_base2=false>
inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
@@ -527,7 +469,6 @@ struct Softmax_base {
int tidx_;
// The elements.
float elt_[MMAS_M * 2][MMAS_N * 4];
__half2 elt_half_[MMAS_M][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -638,34 +579,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
// Scale FP32 fragments
template <typename Mask>
inline __device__ void unpack_noscale_half_and_apply_mask(const Accumulator (&acc)[MMAS_M][MMAS_N],
const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
float tmp[2][4];
// 1st row - 4 elements per row.
tmp[0][0] = mask.is_valid(mi, ni, 0, 0) ? acc[mi][ni].elt(0) : -INFINITY;
tmp[0][1] = mask.is_valid(mi, ni, 0, 1) ? acc[mi][ni].elt(1) : -INFINITY;
tmp[0][2] = mask.is_valid(mi, ni, 0, 2) ? acc[mi][ni].elt(4) : -INFINITY;
tmp[0][3] = mask.is_valid(mi, ni, 0, 3) ? acc[mi][ni].elt(5) : -INFINITY;
// 2nd row - 4 elements per row.
tmp[1][0] = mask.is_valid(mi, ni, 1, 0) ? acc[mi][ni].elt(2) : -INFINITY;
tmp[1][1] = mask.is_valid(mi, ni, 1, 1) ? acc[mi][ni].elt(3) : -INFINITY;
tmp[1][2] = mask.is_valid(mi, ni, 1, 2) ? acc[mi][ni].elt(6) : -INFINITY;
tmp[1][3] = mask.is_valid(mi, ni, 1, 3) ? acc[mi][ni].elt(7) : -INFINITY;
this->elt_half_[mi][4 * ni + 0] = __floats2half2_rn(tmp[0][0], tmp[1][0]);
this->elt_half_[mi][4 * ni + 1] = __floats2half2_rn(tmp[0][1], tmp[1][1]);
this->elt_half_[mi][4 * ni + 2] = __floats2half2_rn(tmp[0][2], tmp[1][2]);
this->elt_half_[mi][4 * ni + 3] = __floats2half2_rn(tmp[0][3], tmp[1][3]);
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
#pragma unroll
@@ -678,18 +591,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
template<typename Operator>
__device__ inline void thread_reduce_(__half2 (&frag)[MMAS_M], Operator &op) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
frag[mi] = this->elt_half_[mi][0];
#pragma unroll
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_half_[mi][ni]);
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
thread_reduce_<zero_init>(frag, op);
@@ -701,29 +602,13 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
quad_allreduce(frag, tmp, op);
}
template<typename Operator>
__device__ inline void reduce_(__half2 (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
thread_reduce_(frag, op);
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_half_t tmp[MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
template<bool zero_init=true>
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_<zero_init>(frag, max, smem_max_);
}
__device__ inline void reduce_max(__half2 (&frag)[MMAS_M]){
MaxOp<__half2> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
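For context before the second file: the float-only reduce_ that remains (thread_reduce_, quad_reduce, smem_red.store, __syncthreads, smem_red.load, quad_allreduce) stages the reduction per thread over the fragment, then across the four threads of a quad with shuffles, then across warps through shared memory. As a rough illustration of the shuffle step only, here is a generic 4-lane butterfly all-reduce with a MaxOp-style functor; it is a sketch that assumes a quad is four consecutive lanes and is not the repository's quad_allreduce:

#include <cstdio>
#include <cuda_runtime.h>

template <typename T> struct MaxOpSketch {
    __device__ T operator()(const T& a, const T& b) const { return a > b ? a : b; }
};

// Butterfly all-reduce over a quad of 4 consecutive lanes: after the two
// XOR-shuffle steps, every lane in the quad holds the quad-wide result
// (the maximum, with the functor above).
template <typename T, typename Op>
__device__ T quad_allreduce_sketch(T val, Op op) {
    val = op(val, __shfl_xor_sync(0xffffffff, val, 1));
    val = op(val, __shfl_xor_sync(0xffffffff, val, 2));
    return val;
}

__global__ void demo(float* out) {
    float v = static_cast<float>(threadIdx.x);  // each lane contributes its lane index
    out[threadIdx.x] = quad_allreduce_sketch(v, MaxOpSketch<float>{});
}

int main() {
    float* d = nullptr;
    cudaMalloc(&d, 32 * sizeof(float));
    demo<<<1, 32>>>(d);
    float h[32];
    cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
    printf("lane 0 sees %f, lane 7 sees %f\n", h[0], h[7]);  // 3 and 7: the max within each quad
    cudaFree(d);
    return 0;
}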


@@ -1024,11 +1024,6 @@ struct MaxOp<float> {
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
template <>
struct MaxOp<__half2> {
__device__ inline __half2 operator()(__half2 const &x, __half2 const &y) { return __hmax2(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>