Move softmax_rescale_o to softmax.h
This commit is contained in:
parent
6777336a1c
commit
df1418f9db
@ -25,39 +25,6 @@ using namespace cute;
|
|||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
|
||||||
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
|
||||||
Tensor2 &acc_o, float softmax_scale_log2) {
|
|
||||||
if (Is_first) {
|
|
||||||
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
|
|
||||||
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
|
||||||
flash::reduce_sum(scores, scores_sum);
|
|
||||||
} else {
|
|
||||||
Tensor scores_max_prev = make_fragment_like(scores_max);
|
|
||||||
cute::copy(scores_max, scores_max_prev);
|
|
||||||
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
|
||||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
|
||||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
|
||||||
#pragma unroll
|
|
||||||
for (int mi = 0; mi < size(scores_max); ++mi) {
|
|
||||||
float scores_max_cur = !Check_inf
|
|
||||||
? scores_max(mi)
|
|
||||||
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
|
|
||||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
|
||||||
scores_sum(mi) *= scores_scale;
|
|
||||||
#pragma unroll
|
|
||||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
|
||||||
}
|
|
||||||
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
|
||||||
Tensor scores_sum_cur = make_fragment_like(scores_sum);
|
|
||||||
flash::reduce_sum(scores, scores_sum_cur);
|
|
||||||
#pragma unroll
|
|
||||||
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||||
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
||||||
|
|
||||||
@ -396,8 +363,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
|||||||
|
|
||||||
// TODO: when we have key_padding_mask we'll need to Check_inf
|
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||||
masking_step == 0
|
masking_step == 0
|
||||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
? flash::softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
: flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||||
|
|
||||||
// Convert scores from fp32 to fp16/bf16
|
// Convert scores from fp32 to fp16/bf16
|
||||||
Tensor rP = flash::convert_type<Element>(scores);
|
Tensor rP = flash::convert_type<Element>(scores);
|
||||||
@ -481,7 +448,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||||
|
|
||||||
Tensor rP = flash::convert_type<Element>(scores);
|
Tensor rP = flash::convert_type<Element>(scores);
|
||||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||||
@ -991,8 +958,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
|||||||
|
|
||||||
// We have key_padding_mask so we'll need to Check_inf
|
// We have key_padding_mask so we'll need to Check_inf
|
||||||
masking_step == 0
|
masking_step == 0
|
||||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
? flash::softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
: flash::softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||||
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
|
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
|
||||||
|
|
||||||
// Convert scores from fp32 to fp16/bf16
|
// Convert scores from fp32 to fp16/bf16
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/******************************************************************************
|
/******************************************************************************
|
||||||
* Copyright (c) 2023, Tri Dao.
|
* Copyright (c) 2024, Tri Dao.
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
@ -115,4 +115,37 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
|
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
||||||
|
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
||||||
|
Tensor2 &acc_o, float softmax_scale_log2) {
|
||||||
|
if (Is_first) {
|
||||||
|
flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
|
||||||
|
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||||
|
flash::reduce_sum(scores, scores_sum);
|
||||||
|
} else {
|
||||||
|
Tensor scores_max_prev = make_fragment_like(scores_max);
|
||||||
|
cute::copy(scores_max, scores_max_prev);
|
||||||
|
flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
||||||
|
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||||
|
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size(scores_max); ++mi) {
|
||||||
|
float scores_max_cur = !Check_inf
|
||||||
|
? scores_max(mi)
|
||||||
|
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
|
||||||
|
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||||
|
scores_sum(mi) *= scores_scale;
|
||||||
|
#pragma unroll
|
||||||
|
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||||
|
}
|
||||||
|
flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||||
|
Tensor scores_sum_cur = make_fragment_like(scores_sum);
|
||||||
|
flash::reduce_sum(scores, scores_sum_cur);
|
||||||
|
#pragma unroll
|
||||||
|
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace flash
|
} // namespace flash
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user