/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/
#pragma once

#include <cmath>

#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>

#include "utils.h"
#include "cutlass/fast_math.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
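// Per-thread reduction: reduce each row of a 2D register fragment into a 1D summary,
// applying `op` across the columns held by this thread.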
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}
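// Cross-thread reduction: Allreduce<4> (defined in utils.h) combines the partial results
// of the 4 threads that share each accumulator row in the MMA layout, so every element of
// dst ends up fully reduced across those lanes.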
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++) {
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}

template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &max) {
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template <bool zero_init=true, bool warp_reduce=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &sum) {
    SumOp<float> sum_op;
    thread_reduce_<zero_init>(tensor, sum, sum_op);
    if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); }
}
__forceinline__ __device__ __half2 half_exp(__half2 x) {
    uint32_t tmp_out, tmp_in;
    tmp_in = reinterpret_cast<uint32_t&>(x);
    asm ("ex2.approx.f16x2 %0, %1;\n"
         : "=r"(tmp_out)
         : "r"(tmp_in));
    __half2 out = reinterpret_cast<__half2&>(tmp_out);
    return out;
}
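// Note: despite the name, half_exp computes 2^x (base-2 exponential) on each fp16 lane via
// ex2.approx; a caller wanting exp(x) would need to pre-multiply the inputs by log2(e).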
// Apply the exp to all the elements.
template <bool zero_init=false, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
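            // (Here `scale` is assumed to already include the log2(e) factor, i.e.
            // scale = softmax_scale * M_LOG2E, so exp2f(x * scale - max_scaled)
            // equals exp((x - max) * softmax_scale).)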
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
    }
}
// Apply the exp to all the elements.
template <bool Scale_max=true, bool Check_inf=true, bool Use_max_offset=false,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
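    // With Use_max_offset, an extra 8 is subtracted in the log2 domain, scaling every exp2
    // output by 2^-8 (presumably to leave headroom, e.g. for low-precision/FP8 accumulation).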
    constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = Check_inf
            ? (max(mi) == -INFINITY ? 0.f : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset)
            : (!Scale_max ? max(mi) : max(mi) * scale) - max_offset;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows, bool Use_max_offset_=false>
struct Softmax {

    constexpr static bool Use_max_offset = Use_max_offset_;
    // constexpr static float max_offset = Use_max_offset ? 8.0f : 0.0f;
    // constexpr static float max_offset_E = max_offset * float(M_LN2);
    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
    TensorT row_max, row_sum;
    const float softmax_scale_log2;
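
    // Note: scale_ is expected to already include the log2(e) factor (i.e. softmax_scale * M_LOG2E),
    // since every exponential below is computed with exp2f.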
    CUTLASS_DEVICE Softmax(float scale_ = 1.f) : softmax_scale_log2(scale_) {};
    template <bool Is_first, bool Check_inf=false, typename Tensor0>
    __forceinline__ __device__ TensorT max(Tensor0 &acc_s) {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scores_scale;
        if constexpr (Is_first) {
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            cute::fill(scores_scale, 1.f);
        } else {
            Tensor scores_max_prev = make_fragment_like(row_max);
            cute::copy(row_max, scores_max_prev);
            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            #pragma unroll
            for (int mi = 0; mi < size(row_max); ++mi) {
                float scores_max_cur = !Check_inf
                    ? row_max(mi)
                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
                scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
                row_sum(mi) *= scores_scale(mi);
            }
        }
        return scores_scale;
    };

    template <bool Is_first, bool Check_inf=false, typename Tensor0>
    __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s) {
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        TensorT scores_scale;
        if constexpr (Is_first) {
            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
            flash::template scale_apply_exp2</*Scale_max=*/true, /*Check_inf=*/true, Use_max_offset>(scores, row_max, softmax_scale_log2);
            flash::reduce_sum</*zero_init=*/true, /*warp_reduce=*/false>(scores, row_sum);
            cute::fill(scores_scale, 1.f);
            // if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); }
        } else {
            // Tensor scores_max_prev = make_fragment_like(row_max);
            // cute::copy(row_max, scores_max_prev);
            // flash::template reduce_max</*zero_init=*/false>(scores, row_max);
            // // if (cute::thread0()) { print_tensor(scores); printf("\n"); print_tensor(row_max); printf("\n"); }
            // #pragma unroll
            // for (int mi = 0; mi < size(row_max); ++mi) {
            //     float scores_max_cur = !Check_inf
            //         ? row_max(mi)
            //         : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
            //     scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            //     row_sum(mi) *= scores_scale(mi);
            // }
            flash::template scale_apply_exp2</*Scale_max=*/true, Check_inf, Use_max_offset>(scores, row_max, softmax_scale_log2);
            // We don't do the reduce across threads here since we don't need to use the row_sum.
            // We do that reduce at the end when we need to normalize the softmax.
            flash::reduce_sum</*zero_init=*/false, /*warp_reduce=*/false>(scores, row_sum);
        }
        return scores_scale;
    };

    template <bool Is_dropout=false, bool Split=false, typename Tensor0>
    __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float descale_v = 1.f, float rp_dropout = 1.f) {
        constexpr static float max_offset_E = Use_max_offset ? 8.f * float(M_LN2) : 0.f;
        // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
        static_assert(decltype(size<0>(scores))::value == kNRows);
        SumOp<float> sum_op;
        quad_allreduce_(row_sum, row_sum, sum_op);
        TensorT scores_scale;
        #pragma unroll
        for (int mi = 0; mi < size(row_max); ++mi) {
            float sum = row_sum(mi);
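            // descale_v (presumably a V descale factor for low-precision/FP8 inputs, 1.f otherwise)
            // is folded into the normalization. row_sum is then overwritten with the row's
            // log-sum-exp in natural-log units: row_max is kept in the log2 domain, hence the
            // M_LN2 conversion; empty or NaN rows get an infinite sentinel instead.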
            float inv_sum = (sum == 0.f || sum != sum) ? 0.f : descale_v / sum;
            row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : (row_max(mi) * softmax_scale_log2) * float(M_LN2) - max_offset_E + __logf(sum);
            scores_scale(mi) = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
        }
        return scores_scale;
    };

    template <typename Tensor1>
    __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) {
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size(row_max); ++mi) {
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); }
        }
    };

};
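
// A minimal usage sketch of the online-softmax flow above (illustrative only; the tensor names
// and the surrounding attention mainloop are assumptions, not part of this header):
//
//     flash::Softmax<kNRows> softmax(softmax_scale * float(M_LOG2E));
//     // First S tile: compute row maxima, exponentiate in place, start the running row sums.
//     softmax.online_softmax</*Is_first=*/true>(acc_s);
//     // Each later tile: refresh the maxima, rescale the accumulated output, then exponentiate.
//     auto scores_scale = softmax.max</*Is_first=*/false, /*Check_inf=*/true>(acc_s);
//     softmax.rescale_o(acc_o, scores_scale);
//     softmax.online_softmax</*Is_first=*/false, /*Check_inf=*/true>(acc_s);
//     // After the last tile: complete the cross-thread sum reduction and normalize the output.
//     softmax.rescale_o(acc_o, softmax.finalize(acc_s));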
} // namespace flash