Clean up softcapping bwd a bit
parent 751c762c9c
commit 5ca83a9c71
@@ -353,7 +353,7 @@ Thanks to @beginlner for this contribution.
 ### 2.6: Softcapping.
 
 Support attention with softcapping, as used in Gemma-2 and Grok models.
-Thanks to @Narsil for this contribution.
+Thanks to @Narsil and @lucidrains for this contribution.
 
 ## Performance
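For context on the math behind the backward-pass changes below: softcapping squashes each attention logit through a scaled tanh before softmax, so the gradient has to multiply in a dtanh factor by the chain rule. A minimal standalone sketch of the transform and its derivative (helper names here are illustrative, not functions from this repo):

    #include <cmath>

    // Softcap as used in Gemma-2 / Grok: t = softcap * tanh(s / softcap).
    // Small logits pass through nearly unchanged; large ones saturate at +/-softcap.
    inline float softcap_fwd(float s, float softcap) {
        return softcap * tanhf(s / softcap);
    }

    // Chain-rule factor: dt/ds = 1 - tanh^2(s / softcap). Given the capped
    // value t, this equals 1 - (t / softcap)^2, so it can be recovered from
    // the stored capped scores without recomputing the tanh.
    inline float softcap_dtanh(float t, float softcap) {
        float u = t / softcap;
        return 1.f - u * u;
    }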
@@ -480,16 +480,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         // if (cute::thread(32, 0)) { print(scores); }
 
         // Softcapping - calculating dTanh and scaling dS later with it
-        auto dtanh = ([&]{
-            if constexpr (Is_softcap) {
-                Tensor _dtanh = make_tensor_like(scores);
-                flash::calculate_dtanh(scores, _dtanh, params.softcap);
-                return _dtanh;
-            }
-            else {
-                return nullptr;
-            }
-        }());
+        Tensor dtanh = make_tensor_like(scores);
+        if constexpr (Is_softcap) {
+            flash::calculate_dtanh(scores, dtanh, params.softcap);
+        }
 
         // Alibi
         if (Has_alibi) {
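This hunk replaces an immediately-invoked lambda, whose if constexpr branches returned either a register tensor or nullptr, with an unconditionally declared dtanh that is filled only when Is_softcap holds; in the non-softcap instantiation the unused tensor should be eliminated by the compiler. The body of flash::calculate_dtanh is not part of this diff; assuming scores holds the tanh-capped values at this point, an elementwise fill could look like this sketch (the real helper may differ):

    // Sketch only: assumes scores(i) = softcap * tanh(s_i / softcap), so the
    // chain-rule factor is 1 - (scores(i) / softcap)^2 per element.
    template <typename TensorS, typename TensorD>
    __forceinline__ __device__ void calculate_dtanh_sketch(TensorS &scores, TensorD &dtanh, const float softcap) {
        #pragma unroll
        for (int i = 0; i < cute::size(scores); ++i) {
            float u = scores(i) / softcap;
            dtanh(i) = 1.f - u * u;
        }
    }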
@@ -591,13 +585,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
-
                 float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
-
-                if constexpr (Is_softcap) {
-                    scaled_ds *= dtanh(mi, ni);
-                }
-
+                if constexpr (Is_softcap) { scaled_ds *= dtanh(mi, ni); }
                 dS(mi, ni) = scaled_ds;
             }
         }
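For reference, the loop body is the standard softmax backward: with P = softmax(S) and the row sums D_i = sum_j P_ij * dP_ij precomputed in dP_sum, the identity is dS_ij = P_ij * (dP_ij - D_i), and softcapping multiplies in the dtanh factor on top. Judging from its arguments, pointwise_mult computes roughly the following (a scalar sketch; the actual helper may also fold in dropout's sign-encoded probabilities):

    // Softmax backward per element, matching the call
    // pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)).
    inline float pointwise_mult_sketch(float p, float dp, float dp_sum_row) {
        return p * (dp - dp_sum_row);
    }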
@@ -99,7 +99,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
                 // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
                 // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                 // If Is_local, set Is_causal to false
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
                 // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
                 if (smem_size_dq_dk_dv >= 48 * 1024) {
                     C10_CUDA_CHECK(cudaFuncSetAttribute(
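The kernel-selection change passes Is_dropout && !Is_softcap instead of Is_dropout, so no dropout variant of the backward kernel is instantiated once softcapping is enabled; presumably the two features are not used together, and this prunes dead template combinations in the same spirit as the IsEvenMNConst reductions described in the comments above. A toy illustration of the effect (hypothetical dispatcher, not the repo's actual switch macros):

    // Of the four (dropout, softcap) combinations, only three distinct kernels
    // get compiled, since dropout is forced off whenever softcap is on.
    template <bool Is_dropout, bool Is_softcap>
    void launch_bwd_sketch() {
        constexpr bool kDropout = Is_dropout && !Is_softcap;  // dropout disabled under softcap
        // instantiate kernel<Kernel_traits, kDropout, ..., Is_softcap> and launch
    }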