Clean up softcapping bwd a bit

Tri Dao 2024-07-22 23:42:06 -07:00
parent 751c762c9c
commit 5ca83a9c71
3 changed files with 7 additions and 18 deletions


@@ -353,7 +353,7 @@ Thanks to @beginlner for this contribution.
 ### 2.6: Softcapping.
 Support attention with softcapping, as used in Gemma-2 and Grok models.
-Thanks to @Narsil for this contribution.
+Thanks to @Narsil and @lucidrains for this contribution.
 ## Performance
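For context on what is being differentiated in the kernel changes below: softcapping squashes each attention logit through a scaled tanh before the softmax, which keeps Gemma-2/Grok-style scores bounded. A minimal sketch, assuming the usual formulation cap * tanh(score / cap) (the kernel's handling of params.softcap may additionally fold in the softmax scale):

```cpp
#include <cmath>

// Hedged sketch, not the kernel code: softcapping with cap t maps a raw
// score s to t * tanh(s / t). The result stays in (-t, t) and is close
// to the identity for |s| << t.
inline float softcap(float s, float t) {
    return t * std::tanh(s / t);
}
```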


@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
         // Softcapping - calculating dTanh and scaling dS later with it
-        auto dtanh = ([&]{
-            if constexpr (Is_softcap) {
-                Tensor _dtanh = make_tensor_like(scores);
-                flash::calculate_dtanh(scores, _dtanh, params.softcap);
-                return _dtanh;
-            }
-            else {
-                return nullptr;
-            }
-        }());
+        Tensor dtanh = make_tensor_like(scores);
+        if constexpr (Is_softcap) {
+            flash::calculate_dtanh(scores, dtanh, params.softcap);
+        }
         // Alibi
         if (Has_alibi) {
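The rewrite above always materializes `dtanh` with `make_tensor_like(scores)` and only fills it under `if constexpr (Is_softcap)`, so `dtanh` has the same tensor type in both instantiations instead of the lambda's branch-dependent `Tensor`-or-`nullptr` result. As a hedged illustration of the quantity being computed (hypothetical helper, assuming `scores` already holds the capped values t * tanh(s / t)):

```cpp
// Hypothetical standalone helper, not flash::calculate_dtanh itself.
// If capped = t * tanh(s / t), then d(capped)/ds = 1 - tanh^2(s / t),
// and since capped / t = tanh(s / t) the derivative is recoverable
// from the stored value without a second tanh evaluation.
inline float dtanh_from_capped(float capped, float t) {
    float u = capped / t;
    return 1.0f - u * u;
}
```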
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-                if constexpr (Is_softcap) {
-                    scaled_ds *= dtanh(mi, ni);
-                }
+                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
             }
         }
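For intuition on why `dtanh` multiplies `scaled_ds`: with q = softcap(s) and p = softmax(q), the softmax backward gives dq = p * (dp - dp_sum), and the chain rule through the softcap contributes one extra pointwise factor. A hedged host-side sketch with hypothetical names (the kernel's `pointwise_mult` also handles dropout sign encoding, which is omitted here):

```cpp
// Hypothetical sketch of the scaling loop above, not the kernel code.
void scale_ds(const float* p, const float* dp, float dp_sum,
              const float* dtanh, float* ds, int n, bool is_softcap) {
    for (int i = 0; i < n; ++i) {
        float scaled = p[i] * (dp[i] - dp_sum);   // softmax backward
        if (is_softcap) { scaled *= dtanh[i]; }   // chain rule through tanh
        ds[i] = scaled;
    }
}
```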


@@ -99,7 +99,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
     if (smem_size_dq_dk_dv >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
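The one functional change in this file is `Is_dropout` becoming `Is_dropout && !Is_softcap` in the kernel template arguments. Like the neighboring comments about `IsEvenMNConst` and `Is_causal`, this folds away a template axis: assuming dropout and softcapping are never requested together (which the `&& !Is_softcap` implies), the dropout-plus-softcap combination need not be instantiated. A hedged sketch of the pattern with hypothetical names:

```cpp
// Hypothetical illustration of the dispatch idea, not the real launcher.
// Every bool template parameter doubles the number of compiled kernels,
// so flags that cannot co-occur are collapsed at the call site.
template <bool Is_dropout, bool Is_softcap>
void launch_kernel() { /* one instantiation per flag combination */ }

template <bool Is_dropout, bool Is_softcap>
void dispatch() {
    // Force the dropout flag off whenever softcapping is on, so the
    // <true, true> kernel is never built.
    launch_kernel<Is_dropout && !Is_softcap, Is_softcap>();
}
```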