backwards for softcapping (#1033)

* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
This commit is contained in:
Phil Wang 2024-07-21 23:25:46 -07:00 committed by GitHub
parent ef3e358a25
commit 5f1ae4a34b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 69 additions and 33 deletions

View File

@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
@ -471,10 +471,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
if constexpr (Is_softcap) {
flash::apply_softcap(acc_s, params.softcap);
}
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread(32, 0)) { print(scores); }
// Softcapping - calculating dTanh and scaling dS later with it
auto dtanh = ([&]{
if constexpr (Is_softcap) {
Tensor _dtanh = make_tensor_like(scores);
flash::calculate_dtanh(scores, _dtanh, params.softcap);
return _dtanh;
}
else {
return nullptr;
}
}());
// Alibi
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
@ -574,7 +591,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
if constexpr (Is_softcap) {
scaled_ds *= dtanh(mi, ni);
}
dS(mi, ni) = scaled_ds;
}
}
// if (cute::thread0()) { print(dS); }
@ -807,7 +831,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the batch.
@ -817,7 +841,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}

View File

@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo
#endif
}
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});

View File

@ -22,14 +22,6 @@ namespace flash {
using namespace cute;
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@ -327,7 +319,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
mask.template apply_mask<Is_causal, Is_even_MN>(
@ -393,7 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();
@ -887,7 +879,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
@ -962,7 +954,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
apply_softcap(acc_s, params.softcap);
flash::apply_softcap(acc_s, params.softcap);
}
flash::cp_async_wait<0>();

View File

@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(tensor); ++i) {
tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
}
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
#pragma unroll
for (int i = 0; i < size(src_tensor); ++i) {
dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@ -261,9 +261,9 @@ def attention_ref(
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
scores /= softcap
scores = scores / softcap
scores = scores.tanh()
scores *= softcap
scores = scores * softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
@ -1122,7 +1122,7 @@ def test_flash_attn_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output(
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
g = torch.randn_like(out)
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
if kvpacked:
(
dq_unpad,