Remove softmax fp16 max
This commit is contained in:
parent
14dc326e59
commit
050873327e
@ -58,12 +58,6 @@ inline __device__ float apply_exp_(float x, float max) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ __half2 apply_exp_(__half2 x, __half2 max) {
|
||||
return h2exp(__hsub2(x, max));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float apply_exp2_(float x, float max) {
|
||||
return exp2f(x - max);
|
||||
// With fast-math, this produces the same PTX instruction as the assembly below
|
||||
@ -75,17 +69,9 @@ inline __device__ float apply_exp2_(float x, float max) {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ __half2 apply_exp2_(__half2 x, __half2 max) {
|
||||
return h2exp2(__hsub2(x, max));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int COLS, bool half> struct ReadType {};
|
||||
template<> struct ReadType<4, false> { using T = float;};
|
||||
template<> struct ReadType<8, false> { using T = float2;};
|
||||
template<> struct ReadType<4, true> { using T = __half2;};
|
||||
template<> struct ReadType<8, true> { using T = float2;};
|
||||
template<int COLS> struct ReadType {};
|
||||
template<> struct ReadType<4> { using T = float;};
|
||||
template<> struct ReadType<8> { using T = float2;};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -118,8 +104,7 @@ struct Smem_tile_reduce {
|
||||
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
|
||||
static_assert(LOOPS == 1);
|
||||
|
||||
using read_t = typename ReadType<COLS, /*half=*/false>::T;
|
||||
using read_half_t = typename ReadType<COLS, /*half=*/true>::T;
|
||||
using read_t = typename ReadType<COLS>::T;
|
||||
|
||||
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
|
||||
|
||||
@ -152,17 +137,6 @@ struct Smem_tile_reduce {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void store(__half2 (&frag)[MMAS_M]) {
|
||||
__half2 *smem_write_half_ = reinterpret_cast<__half2 *>(smem_write_);
|
||||
if( qid_ == 0 ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * WARPS_N;
|
||||
smem_write_half_[offset + 0 * 8 * WARPS_N] = frag[mi];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
@ -172,15 +146,6 @@ struct Smem_tile_reduce {
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load(read_half_t (&frag)[MMAS_M]) {
|
||||
read_half_t *smem_read_half_ = reinterpret_cast<read_half_t *>(smem_read_);
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * 4;
|
||||
frag[mi] = smem_read_half_[offset + 0 * 8 * 4];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
@ -304,29 +269,6 @@ struct Softmax_base {
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
inline __device__ void apply_exp(const __half2 (&max)[MMAS_M]) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
const float2 max_f = __half22float2(max[mi]);
|
||||
const float max0_log2e = max_f.x * kLog2e, max1_log2e = max_f.y * kLog2e;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < MMAS_N * 4; ++ni) {
|
||||
float2 elt = __half22float2(elt_half_[mi][ni]);
|
||||
elt_[mi * 2 + 0][ni] = apply_exp2_(elt.x * kLog2e, max0_log2e);
|
||||
elt_[mi * 2 + 1][ni] = apply_exp2_(elt.y * kLog2e, max1_log2e);
|
||||
// __half2 out = apply_exp_(elt_half_[mi][ni], max[mi]);
|
||||
// float2 outf = __half22float2(out);
|
||||
// elt_[mi * 2 + 0][ni] = outf.x;
|
||||
// elt_[mi * 2 + 1][ni] = outf.y;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool max_in_base2=false>
|
||||
inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
|
||||
@ -527,7 +469,6 @@ struct Softmax_base {
|
||||
int tidx_;
|
||||
// The elements.
|
||||
float elt_[MMAS_M * 2][MMAS_N * 4];
|
||||
__half2 elt_half_[MMAS_M][MMAS_N * 4];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -638,34 +579,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
|
||||
}
|
||||
}
|
||||
|
||||
// Scale FP32 fragments
|
||||
template <typename Mask>
|
||||
inline __device__ void unpack_noscale_half_and_apply_mask(const Accumulator (&acc)[MMAS_M][MMAS_N],
|
||||
const Mask &mask) {
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
float tmp[2][4];
|
||||
// 1st row - 4 elements per row.
|
||||
tmp[0][0] = mask.is_valid(mi, ni, 0, 0) ? acc[mi][ni].elt(0) : -INFINITY;
|
||||
tmp[0][1] = mask.is_valid(mi, ni, 0, 1) ? acc[mi][ni].elt(1) : -INFINITY;
|
||||
tmp[0][2] = mask.is_valid(mi, ni, 0, 2) ? acc[mi][ni].elt(4) : -INFINITY;
|
||||
tmp[0][3] = mask.is_valid(mi, ni, 0, 3) ? acc[mi][ni].elt(5) : -INFINITY;
|
||||
// 2nd row - 4 elements per row.
|
||||
tmp[1][0] = mask.is_valid(mi, ni, 1, 0) ? acc[mi][ni].elt(2) : -INFINITY;
|
||||
tmp[1][1] = mask.is_valid(mi, ni, 1, 1) ? acc[mi][ni].elt(3) : -INFINITY;
|
||||
tmp[1][2] = mask.is_valid(mi, ni, 1, 2) ? acc[mi][ni].elt(6) : -INFINITY;
|
||||
tmp[1][3] = mask.is_valid(mi, ni, 1, 3) ? acc[mi][ni].elt(7) : -INFINITY;
|
||||
this->elt_half_[mi][4 * ni + 0] = __floats2half2_rn(tmp[0][0], tmp[1][0]);
|
||||
this->elt_half_[mi][4 * ni + 1] = __floats2half2_rn(tmp[0][1], tmp[1][1]);
|
||||
this->elt_half_[mi][4 * ni + 2] = __floats2half2_rn(tmp[0][2], tmp[1][2]);
|
||||
this->elt_half_[mi][4 * ni + 3] = __floats2half2_rn(tmp[0][3], tmp[1][3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Operator>
|
||||
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
|
||||
#pragma unroll
|
||||
@ -678,18 +591,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Operator>
|
||||
__device__ inline void thread_reduce_(__half2 (&frag)[MMAS_M], Operator &op) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
frag[mi] = this->elt_half_[mi][0];
|
||||
#pragma unroll
|
||||
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
|
||||
frag[mi] = op(frag[mi], this->elt_half_[mi][ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Operator>
|
||||
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
|
||||
thread_reduce_<zero_init>(frag, op);
|
||||
@ -701,29 +602,13 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
|
||||
quad_allreduce(frag, tmp, op);
|
||||
}
|
||||
|
||||
template<typename Operator>
|
||||
__device__ inline void reduce_(__half2 (&frag)[MMAS_M], Operator &op, Smem_tile_red & smem_red) {
|
||||
thread_reduce_(frag, op);
|
||||
quad_reduce(frag, frag, op);
|
||||
smem_red.store(frag);
|
||||
__syncthreads();
|
||||
typename Smem_tile_red::read_half_t tmp[MMAS_M];
|
||||
smem_red.load(tmp);
|
||||
quad_allreduce(frag, tmp, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true>
|
||||
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
|
||||
MaxOp<float> max;
|
||||
reduce_<zero_init>(frag, max, smem_max_);
|
||||
}
|
||||
|
||||
__device__ inline void reduce_max(__half2 (&frag)[MMAS_M]){
|
||||
MaxOp<__half2> max;
|
||||
reduce_(frag, max, smem_max_);
|
||||
}
|
||||
|
||||
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
|
||||
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
|
||||
SumOp<float> sum;
|
||||
reduce_(frag, sum, smem_sum_);
|
||||
}
|
||||
|
||||
@ -1024,11 +1024,6 @@ struct MaxOp<float> {
|
||||
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<__half2> {
|
||||
__device__ inline __half2 operator()(__half2 const &x, __half2 const &y) { return __hmax2(x, y); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user