Make Softmax an object
parent 4ea866ca19
commit 6f706eff96

This commit moves the running softmax statistics (row_max, row_sum) and the rescale/normalize routines out of free functions and into a flash::Softmax<kNRows> object that the forward kernels instantiate once and carry across key blocks.
@@ -180,10 +180,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
 
-    // TODO: this might need to change if we change the mma instruction in SM70
-    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
-    Tensor scores_sum = make_fragment_like(scores_max);
-
     //
     // PREDICATES
     //
@@ -267,6 +263,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     clear(acc_o);
 
+    flash::Softmax<2 * size<1>(acc_o)> softmax;
+
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
 
@@ -357,8 +355,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? flash::softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
+            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
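For orientation: softmax_rescale_o is the online-softmax update, and the 2 * size<1>(acc_o) template argument above is the number of accumulator rows each thread owns (two per 16-row MMA tile under the SM80 MMA layout). When a later key block raises a row's running max, the running sum and the partial output in acc_o must be rescaled by exp(old_max - new_max) before the new block is folded in. Below is a minimal single-row scalar sketch of that update in plain host-side C++, with hypothetical names, omitting the kernel's base-2 trick of computing exp2f((s - max) * scale_softmax_log2) instead of expf:

#include <cmath>
#include <vector>

// Single-row scalar model of Softmax::softmax_rescale_o (illustration only;
// the real code operates on CuTe register fragments). `scores` is one block
// of raw attention logits for this row; `acc_o` is the running, still
// unnormalized output row.
struct OnlineSoftmaxRow {
    float row_max = -INFINITY;
    float row_sum = 0.f;

    void softmax_rescale_o(std::vector<float> &scores, std::vector<float> &acc_o) {
        float prev_max = row_max;
        for (float s : scores) row_max = std::fmax(row_max, s);
        if (prev_max != -INFINITY) {
            // The max may have grown: bring the old sum and the old partial
            // output into the new max's frame of reference. This is the
            // scores_scale loop in the kernel (with Check_inf guarding rows
            // whose max is still -INFINITY because everything was masked).
            float scale = std::exp(prev_max - row_max);
            row_sum *= scale;
            for (float &o : acc_o) o *= scale;
        }
        // Exponentiate this block in place and fold it into the running sum;
        // the caller then accumulates exp(scores) @ V into acc_o.
        for (float &s : scores) {
            s = std::exp(s - row_max);
            row_sum += s;
        }
    }
};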
@@ -435,7 +433,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             );
         }
 
-        flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
         Tensor rP = flash::convert_type<Element>(scores);
         int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
@@ -461,20 +459,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 
     // Epilogue
 
-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-    Tensor lse = make_fragment_like(scores_sum);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
-        float sum = scores_sum(mi);
-        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
-        float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
-        #pragma unroll
-        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
-    }
-
-    // if (cute::thread0()) { print(acc_o_rowcol); }
+    Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
 
     // Convert acc_o from fp32 to fp16/bf16
     Tensor rO = flash::convert_type<Element>(acc_o);
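normalize_softmax_lse absorbs the hand-written epilogue deleted above: divide the accumulated output by the per-row sum (folding in the dropout correction rp_dropout when Is_dropout is set) and return the per-row log-sum-exp row_max * softmax_scale + log(row_sum), with a zero or NaN sum (a fully masked row) mapped to an infinite LSE and a no-op scale. A scalar sketch under the same assumptions as before; note softmax_scale is the 1/sqrt(head_dim) factor the kernel folds into the exponentials, so with the earlier sketch (which omits it) it would be 1:

#include <cmath>

// Single-row scalar model of Softmax::normalize_softmax_lse (illustration
// only). row_max/row_sum are the statistics accumulated by the rescale step;
// acc_o is the unnormalized output row of length n.
float normalize_softmax_lse(float *acc_o, int n, float row_max, float row_sum,
                            float softmax_scale, float rp_dropout = 1.0f) {
    bool bad = (row_sum == 0.f || row_sum != row_sum);  // sum != sum catches NaN
    float lse = bad ? INFINITY : row_max * softmax_scale + std::log(row_sum);
    float inv_sum = bad ? 1.f : 1.f / row_sum;
    // rp_dropout = 1 / (1 - p_dropout); it is 1.0 when dropout is off.
    float scale = inv_sum * rp_dropout;
    for (int i = 0; i < n; ++i) acc_o[i] *= scale;
    return lse;
}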
@@ -685,11 +670,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
-
-    // TODO: this might need to change if we change the mma instruction in SM70
-    Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
-    Tensor scores_sum = make_fragment_like(scores_max);
-
     //
     // PREDICATES
     //
 
@@ -862,6 +842,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     clear(acc_o);
 
+    flash::Softmax<2 * size<1>(acc_o)> softmax;
+
     const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
     flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
 
@@ -939,8 +921,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
         // We have key_padding_mask so we'll need to Check_inf
         masking_step == 0
-            ? flash::softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
+            : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
         // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
 
         // Convert scores from fp32 to fp16/bf16
@@ -1002,7 +984,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
                                    params.window_size_left, params.window_size_right
             );
         }
-        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+        softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
 
         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
@@ -1014,21 +996,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 
     // Epilogue
 
-    // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-    // if (cute::thread0()) { print(acc_o_rowcol); }
-    Tensor lse = make_fragment_like(scores_sum);
-    #pragma unroll
-    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
-        float sum = scores_sum(mi);
-        float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
-        lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
-        float scale = inv_sum;
-        #pragma unroll
-        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
-    }
+    Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
     // if (cute::thread0()) { print(lse); }
-    // if (cute::thread0()) { print(acc_o_rowcol); }
 
     Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
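One behavioral detail in the split-kv path: normalize_softmax_lse is instantiated with Split, so a row whose sum is zero or NaN records lse = -INFINITY instead of INFINITY. The sign matters because per-split partial results are later merged with a logsumexp weighting, and exp(-inf) == 0 lets a split that saw no valid keys drop out cleanly, whereas +inf would poison the combined sum. A small self-contained illustration of that weighting (made-up values; this is not the combine kernel itself):

#include <cmath>
#include <cstdio>

int main() {
    // Per-split (lse, normalized output) for one row; split 1 saw no keys.
    float lse[2] = {1.7f, -INFINITY};
    float out[2] = {0.42f, 0.f};

    // Logsumexp-weighted combine: w_i = exp(lse_i - max_lse).
    float max_lse = std::fmax(lse[0], lse[1]);
    float sum = 0.f, combined = 0.f;
    for (int s = 0; s < 2; ++s) {
        float w = std::exp(lse[s] - max_lse);  // exp(-inf) == 0: empty split vanishes
        sum += w;
        combined += w * out[s];
    }
    combined /= sum;
    std::printf("combined = %f (matches out[0] = %f)\n", combined, out[0]);
    return 0;
}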
@@ -117,35 +117,64 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
-inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
-                                         Tensor2 &acc_o, float softmax_scale_log2) {
-    if (Is_first) {
-        flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        flash::reduce_sum(scores, scores_sum);
-    } else {
-        Tensor scores_max_prev = make_fragment_like(scores_max);
-        cute::copy(scores_max, scores_max_prev);
-        flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
-        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
-        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_max); ++mi) {
-            float scores_max_cur = !Check_inf
-                ? scores_max(mi)
-                : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
-            float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
-            scores_sum(mi) *= scores_scale;
-            #pragma unroll
-            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
-        }
-        flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
-        Tensor scores_sum_cur = make_fragment_like(scores_sum);
-        flash::reduce_sum(scores, scores_sum_cur);
-        #pragma unroll
-        for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
-    }
-};
+template <int kNRows>
+struct Softmax {
+
+    using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
+    TensorT row_max, row_sum;
+
+    inline __device__ Softmax() {};
+
+    template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
+    inline __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
+        // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+        Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+        static_assert(decltype(size<0>(scores))::value == kNRows);
+        if (Is_first) {
+            flash::template reduce_max</*zero_init=*/true>(scores, row_max);
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            flash::reduce_sum(scores, row_sum);
+        } else {
+            Tensor scores_max_prev = make_fragment_like(row_max);
+            cute::copy(row_max, scores_max_prev);
+            flash::template reduce_max</*zero_init=*/false>(scores, row_max);
+            // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+            Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+            static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_max); ++mi) {
+                float scores_max_cur = !Check_inf
+                    ? row_max(mi)
+                    : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
+                float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+                row_sum(mi) *= scores_scale;
+                #pragma unroll
+                for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
+            }
+            flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
+            Tensor scores_sum_cur = make_fragment_like(row_sum);
+            flash::reduce_sum(scores, scores_sum_cur);
+            #pragma unroll
+            for (int mi = 0; mi < size(row_sum); ++mi) { row_sum(mi) += scores_sum_cur(mi); }
+        }
+    };
+
+    template<bool Is_dropout=false, bool Split=false, typename Tensor0>
+    inline __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
+        TensorT lse = make_fragment_like(row_sum);
+        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
+        #pragma unroll
+        for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
+            float sum = row_sum(mi);
+            float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+            lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
+            float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
+            #pragma unroll
+            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
+        }
+        return lse;
+    };
+};
 
 } // namespace flash
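Since all of the running state (row_max, row_sum) now lives inside the object instead of in fragments threaded through free-function calls, the contract is easy to check on the host: streaming a row block by block through rescale updates, then normalizing once, must match a one-shot softmax. A compact plain-C++ check of that invariant (scalar stand-ins, not the CUDA code; softmax_scale taken as 1 for simplicity):

#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar stand-in for flash::Softmax, carrying the same state across blocks.
struct SoftmaxRow {
    float row_max = -INFINITY, row_sum = 0.f;

    void softmax_rescale_o(std::vector<float> &scores, std::vector<float> &acc_o) {
        float prev_max = row_max;
        for (float s : scores) row_max = std::fmax(row_max, s);
        float scale = (prev_max == -INFINITY) ? 0.f : std::exp(prev_max - row_max);
        row_sum *= scale;
        for (float &o : acc_o) o *= scale;
        for (float &s : scores) { s = std::exp(s - row_max); row_sum += s; }
    }

    float normalize_softmax_lse(std::vector<float> &acc_o) {
        for (float &o : acc_o) o /= row_sum;
        return row_max + std::log(row_sum);
    }
};

int main() {
    std::vector<float> logits = {0.3f, -1.2f, 2.5f, 0.9f, -0.4f, 1.1f};
    std::vector<float> v      = {1.0f,  2.0f, 3.0f, 4.0f,  5.0f, 6.0f};  // head_dim = 1

    // Stream the row through k-blocks of 2, the way the kernel's main loop does.
    SoftmaxRow softmax;
    std::vector<float> acc_o = {0.f};
    for (size_t b = 0; b < logits.size(); b += 2) {
        std::vector<float> block(logits.begin() + b, logits.begin() + b + 2);
        softmax.softmax_rescale_o(block, acc_o);
        for (size_t i = 0; i < block.size(); ++i) acc_o[0] += block[i] * v[b + i];  // "P @ V"
    }
    float lse = softmax.normalize_softmax_lse(acc_o);

    // One-shot reference: softmax(logits) @ v.
    float m = -INFINITY, z = 0.f, ref = 0.f;
    for (float s : logits) m = std::fmax(m, s);
    for (size_t i = 0; i < logits.size(); ++i) {
        float p = std::exp(logits[i] - m);
        z += p; ref += p * v[i];
    }
    ref /= z;

    std::printf("online = %f  reference = %f  lse = %f\n", acc_o[0], ref, lse);
    assert(std::fabs(acc_o[0] - ref) < 1e-4f);
    return 0;
}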