Remove unused sdPsum in dot_do_o function

Tri Dao 2023-09-03 20:44:07 -07:00
parent b28ec236df
commit 5953c4f58c
2 changed files with 10 additions and 16 deletions

View File

@@ -73,11 +73,9 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
-          typename Engine1, typename Layout1, typename Engine2, typename Layout2>
+template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
-                                Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
-                                const int gdP_col_stride, const float scale) {
+                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
     static_assert(Layout0::rank == 3, "Only support 3D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
@@ -100,7 +98,6 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
         dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
         if (threadIdx.x % THREADS_PER_ROW == 0) {
             dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
-            // recast<float>(sdPsum)(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum;
         }
     }
 }
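For context on the hunk above: flash::Allreduce<THREADS_PER_ROW>::run makes the per-row partial dot product visible to all THREADS_PER_ROW participating lanes via a butterfly of warp shuffles, after which only the lane with threadIdx.x % THREADS_PER_ROW == 0 writes dP_sum. A minimal self-contained sketch of that pattern (an illustration, not the repository's exact implementation; it assumes THREADS_PER_ROW is a power of two and at most one warp wide):

template <int THREADS_PER_ROW, typename T, typename Operator>
__device__ inline T allreduce_sketch(T x, Operator &op) {
    // Butterfly reduction: after log2(THREADS_PER_ROW) rounds of
    // xor-shuffles, every participating lane holds the full reduction.
    #pragma unroll
    for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
        x = op(x, __shfl_xor_sync(0xffffffffu, x, offset));
    }
    return x;
}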
@@ -178,7 +175,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
     // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
     // so that (dP - dP_sum) is on the same scale.
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum, dP_sum,
+    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     if (Clear_dQaccum) {
         Tensor zero = make_fragment_like(tdQgdQaccum);
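The comment in the hunk above refers to the softmax-backward identity that dP_sum feeds into. Writing D_i for the quantity dot_do_o computes:

\[
D_i = \sum_{d} dO_{id}\, O_{id} = \sum_{k} dP_{ik}\, P_{ik},
\qquad
dS_{ij} = P_{ij}\,\bigl(dP_{ij} - D_i\bigr).
\]

Since dP is left scaled down by p_dropout (the 1/p_dropout correction is deferred to the final dQ and dK), passing params.p_dropout as the scale keeps D_i on that same scale, so the difference (dP_ij - D_i) stays consistent.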
@@ -517,8 +514,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
     // sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
-    Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
-                                Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
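The "be careful" in the comment above is the CUTE idiom of layering two layouts over one shared-memory pointer: sP and sdQ are distinct views of the same bytes, so a __syncthreads() must separate the last read of P from the first write of dQ. A schematic sketch (hypothetical float element type and layouts, not the kernel's real ones):

#include <cute/tensor.hpp>
using namespace cute;

__device__ void alias_sketch(float *smem) {
    // Two tensors over the same buffer; only one view may be "live" at a time.
    Tensor sP  = make_tensor(make_smem_ptr(smem), make_layout(Shape<_64, _64>{}));
    Tensor sdQ = make_tensor(make_smem_ptr(smem), make_layout(Shape<_64, _64>{}));
    sP(threadIdx.x % 64, 0) = 0.f;   // phase 1: P lives here
    __syncthreads();                 // all reads of P must complete...
    sdQ(threadIdx.x % 64, 0) = 1.f;  // ...before dQ overwrites the buffer
}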
@@ -733,7 +728,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -930,11 +925,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         #pragma unroll
         for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
         gdPsum.data() = gdPsum.data() + (-int(kBlockM));
-        // if (!Is_first && tidx < kBlockM / 2) {
-        //     sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
-        // if (!Is_first && tidx < kBlockM) {
-        //     recast<float>(sdPsum)(tidx) = gdPsum(tidx);
-        // }
     }
 
     if (!Is_last) {
@@ -976,7 +966,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
@@ -1317,7 +1307,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     Tensor dP_sum = make_fragment_like(lse);
     cute::copy(tdOrdO, tdOsdO);
     dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
-        tdOrdO, tdOrO, sdPsum, sdPsum,
+        tdOrdO, tdOrO, sdPsum,
         Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
     );
     __syncthreads();

View File

@@ -321,6 +321,10 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
     static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
     static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+    static constexpr int kSmemSize = kSmemQdOSize
+        + (!Is_V_in_regs
+               ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+               : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
     static constexpr int kSmemSize1colblock = kSmemQdOSize
         + (!Is_V_in_regs
               ? kSmemKVSize + kSmemdSSize + kSmemPSize
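The new kSmemSize expression appears to price in the Is_V_in_regs optimization: when V is staged through registers, its half of the K/V buffer is only needed transiently during the load, so the dS and P/dQ buffers can overlap it afterwards (hence the kSmemKVSize / 2 term). A worked numeric sketch, with hypothetical byte counts standing in for the kernel-traits constants:

#include <algorithm>
#include <cstdio>

constexpr int kSmemQdOSize = 24576;  // hypothetical stand-in values,
constexpr int kSmemKVSize  = 16384;  // chosen only to make the
constexpr int kSmemdSSize  = 8192;   // arithmetic visible
constexpr int kSmemPSize   = 8192;
constexpr int kSmemdQSize  = 8192;

constexpr int smem_size(bool Is_V_in_regs) {
    // The max() covers the transient phase where the full K/V buffer is
    // live; afterwards only K's half plus dS plus max(P, dQ) is needed.
    return kSmemQdOSize
        + (!Is_V_in_regs
               ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
               : std::max(kSmemKVSize,
                          kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
}

int main() {
    std::printf("V in smem: %d bytes\n", smem_size(false));  // 57344
    std::printf("V in regs: %d bytes\n", smem_size(true));   // 49152
}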