backwards for softcapping (#1033)
* check in the two ways of approaching backwards for softcapping, both functional
* prepare the softcap switch for backwards
* temporary
* cleanup to the way Tri prefers
* calculate dtanh when copying from scores -> dtanh Tensor
* no ternary operators allowed for constexpr, so just use some hack found online
* fix maybe_dtanh, restore some files
* restore another file
* move calculate_dtanh to utils and colocate with apply_softcap
* cleanup
* maybe last cleanup
* save for another pr
* remove a stray line
* fix spacing
* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
parent ef3e358a25
commit 5f1ae4a34b
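
Note (not part of the diff): the reference implementation in test_flash_attn.py below computes softcapping as softcap * tanh(scores / softcap), and its elementwise derivative with respect to the raw scores is 1 - tanh(scores / softcap)^2, which is the "dtanh" factor the backward pass multiplies into dS. A minimal scalar sketch of that reference math (the kernel itself folds the softmax scale into params.softcap and defers the outer multiply, so its constants differ):

```cpp
#include <cmath>
#include <cstdio>

// Forward: y = softcap * tanh(x / softcap); as softcap grows this approaches y = x.
float softcap_fwd(float x, float softcap) {
    return softcap * std::tanh(x / softcap);
}

// Backward: dy/dx = 1 - tanh(x / softcap)^2.  Note it can be formed from the
// already-tanh'd value t = tanh(x / softcap) as (1 - t * t), which is why the
// kernel derives dtanh from the post-softcap scores instead of re-evaluating tanh.
float softcap_bwd(float x, float softcap) {
    float t = std::tanh(x / softcap);
    return 1.0f - t * t;
}

int main() {
    float x = 3.0f, c = 30.0f;
    std::printf("y = %f, dy/dx = %f\n", softcap_fwd(x, c), softcap_bwd(x, c));
    return 0;
}
```
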
@@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
 inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
 
     using Element = typename Kernel_traits::Element;
@@ -471,10 +471,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
                     smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
 
+        if constexpr (Is_softcap) {
+            flash::apply_softcap(acc_s, params.softcap);
+        }
+
         // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
         // if (cute::thread(32, 0)) { print(scores); }
 
+        // Softcapping - calculating dTanh and scaling dS later with it
+        auto dtanh = ([&]{
+            if constexpr (Is_softcap) {
+                Tensor _dtanh = make_tensor_like(scores);
+                flash::calculate_dtanh(scores, _dtanh, params.softcap);
+                return _dtanh;
+            }
+            else {
+                return nullptr;
+            }
+        }());
+
         // Alibi
         if (Has_alibi) {
             alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
                               m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
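
The `auto dtanh = ([&]{ ... }())` block above is the workaround the commit message calls "some hack found online": a ternary cannot yield two different types at compile time, but an immediately-invoked lambda with `if constexpr` can, because the discarded branch does not participate in return-type deduction when the condition depends on a template parameter. A self-contained sketch of the idiom (names here are illustrative, not taken from the diff):

```cpp
#include <cstdio>

template <bool Enabled>
void demo() {
    // Immediately-invoked lambda: the deduced type of `extra` differs per instantiation.
    auto extra = ([&] {
        if constexpr (Enabled) {
            return 3.14f;    // deduces float when Enabled == true
        } else {
            return nullptr;  // deduces std::nullptr_t when Enabled == false
        }
    }());

    if constexpr (Enabled) {
        std::printf("extra = %f\n", extra);
    } else {
        (void)extra;  // unused in the disabled case
        std::printf("extra is unused\n");
    }
}

int main() {
    demo<true>();
    demo<false>();
    return 0;
}
```
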
@@ -574,7 +591,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         for (int mi = 0; mi < size<0>(dS); ++mi) {
             #pragma unroll
             for (int ni = 0; ni < size<1>(dS); ++ni) {
-                dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+
+                float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+
+                if constexpr (Is_softcap) {
+                    scaled_ds *= dtanh(mi, ni);
+                }
+
+                dS(mi, ni) = scaled_ds;
             }
         }
         // if (cute::thread0()) { print(dS); }
@@ -807,7 +831,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
 inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // The block index for the batch.
@@ -817,7 +841,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
 
     // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
     for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+        compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
     }
 }
 
@@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo
     #endif
 }
 
-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
     #if defined(ARCH_SUPPORTS_FLASH)
         static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
-        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
+        flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
     #else
         FLASH_UNSUPPORTED_ARCH
     #endif
@@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
         LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
             ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-                // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-                // If Is_local, set Is_causal to false
-                auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-                // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-                if (smem_size_dq_dk_dv >= 48 * 1024) {
-                    C10_CUDA_CHECK(cudaFuncSetAttribute(
-                        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-                }
-                kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-                C10_CUDA_KERNEL_LAUNCH_CHECK();
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+                    // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
+                    if (smem_size_dq_dk_dv >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+                    }
+                    kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
+                });
             });
         });
     });
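
SOFTCAP_SWITCH above turns the runtime check `params.softcap > 0.0` into the compile-time `Is_softcap` template parameter. Its definition is not shown in this diff; the sketch below only illustrates the BOOL_SWITCH-style dispatch pattern such a macro typically wraps (the macro body and names here are assumptions, not the repository's actual static_switch.h):

```cpp
#include <cstdio>

// BOOL_SWITCH-style dispatch: a runtime bool picks one of two instantiations of the
// callback, each of which sees the value as a constexpr constant usable as a template argument.
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...)          \
    [&] {                                              \
        if (COND) {                                    \
            constexpr static bool CONST_NAME = true;   \
            return __VA_ARGS__();                      \
        } else {                                       \
            constexpr static bool CONST_NAME = false;  \
            return __VA_ARGS__();                      \
        }                                              \
    }()

template <bool Is_softcap>
void launch() {
    std::printf("Is_softcap = %s\n", Is_softcap ? "true" : "false");
}

int main() {
    float softcap = 30.0f;  // stand-in for params.softcap
    SOFTCAP_SWITCH(softcap > 0.0f, Is_softcap, [&] { launch<Is_softcap>(); });
    return 0;
}
```
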
@@ -22,14 +22,6 @@ namespace flash {
 
 using namespace cute;
 
-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-    #pragma unroll
-    for (int i = 0; i < size(tensor); ++i) {
-        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-    }
-}
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@@ -327,7 +319,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -393,7 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         flash::cp_async_wait<0>();
@@ -887,7 +879,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         );
         // if (cute::thread0()) { print(acc_s); }
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
 
@@ -962,7 +954,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             smem_thr_copy_Q, smem_thr_copy_K
         );
         if constexpr (Is_softcap){
-            apply_softcap(acc_s, params.softcap);
+            flash::apply_softcap(acc_s, params.softcap);
         }
 
         flash::cp_async_wait<0>();
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(tensor); ++i) {
+        tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+    }
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+    #pragma unroll
+    for (int i = 0; i < size(src_tensor); ++i) {
+        dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
 } // namespace flash
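
One subtlety in the two helpers added above: calculate_dtanh receives the scores after apply_softcap has already replaced them with tanh values, so the derivative of tanh(x * softcap) with respect to x is formed as (1 - s * s) * softcap from the cached s rather than by re-evaluating tanh. A standalone host-side sketch on plain arrays that mimics (but does not reuse) the helpers:

```cpp
#include <cmath>
#include <cstdio>

// Mimics flash::apply_softcap on a plain array: x[i] <- tanh(x[i] * softcap).
void apply_softcap_ref(float* x, int n, float softcap) {
    for (int i = 0; i < n; ++i) { x[i] = std::tanh(x[i] * softcap); }
}

// Mimics flash::calculate_dtanh: given already-tanh'd scores s,
// d/dx tanh(x * softcap) = (1 - s * s) * softcap.
void calculate_dtanh_ref(const float* s, float* d, int n, float softcap) {
    for (int i = 0; i < n; ++i) { d[i] = (1.f - s[i] * s[i]) * softcap; }
}

int main() {
    float scores[4] = {-2.f, -0.5f, 0.5f, 2.f};
    float dtanh[4];
    const float softcap = 0.5f;  // illustrative value; the kernel's params.softcap is preprocessed on the host
    apply_softcap_ref(scores, 4, softcap);
    calculate_dtanh_ref(scores, dtanh, 4, softcap);
    for (int i = 0; i < 4; ++i) { std::printf("s = % .4f, dtanh = % .4f\n", scores[i], dtanh[i]); }
    return 0;
}
```
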
@@ -261,9 +261,9 @@ def attention_ref(
     else:
         scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
     if softcap > 0:
-        scores /= softcap
+        scores = scores / softcap
         scores = scores.tanh()
-        scores *= softcap
+        scores = scores * softcap
     if key_padding_mask is not None:
         scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
     if window_size[0] >= 0 or window_size[1] >= 0:
@@ -1122,7 +1122,7 @@ def test_flash_attn_output(
     if not alibi:
         assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
 
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
         assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
         assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
         assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output(
     print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")
 
     g = torch.randn_like(out)
-    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+    if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
         if kvpacked:
             (
                 dq_unpad,