Remove unused sdPsum in dot_do_o function
This commit is contained in:
parent
b28ec236df
commit
5953c4f58c
@ -73,11 +73,9 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
|
||||
typename Engine1, typename Layout1, typename Engine2, typename Layout2>
|
||||
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
|
||||
Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
|
||||
const int gdP_col_stride, const float scale) {
|
||||
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
|
||||
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
|
||||
@ -100,7 +98,6 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
|
||||
dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
|
||||
if (threadIdx.x % THREADS_PER_ROW == 0) {
|
||||
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
|
||||
// recast<float>(sdPsum)(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -178,7 +175,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) {
|
||||
// By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
|
||||
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
|
||||
// so that (dP - dP_sum) is on the same scale.
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum, dP_sum,
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
if (Clear_dQaccum) {
|
||||
Tensor zero = make_fragment_like(tdQgdQaccum);
|
||||
@ -517,8 +514,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
|
||||
// sP and sdQ share the same memory so be careful
|
||||
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
|
||||
Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
|
||||
Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
@ -733,7 +728,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
|
||||
if (Is_first) {
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
}
|
||||
|
||||
@ -930,11 +925,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
|
||||
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
|
||||
// if (!Is_first && tidx < kBlockM / 2) {
|
||||
// sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
|
||||
// if (!Is_first && tidx < kBlockM) {
|
||||
// recast<float>(sdPsum)(tidx) = gdPsum(tidx);
|
||||
// }
|
||||
}
|
||||
|
||||
if (!Is_last) {
|
||||
@ -976,7 +966,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
|
||||
if (Is_first && m_block > m_block_min) {
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
}
|
||||
|
||||
@ -1317,7 +1307,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
|
||||
Tensor dP_sum = make_fragment_like(lse);
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
|
||||
tdOrdO, tdOrO, sdPsum, sdPsum,
|
||||
tdOrdO, tdOrO, sdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
|
||||
);
|
||||
__syncthreads();
|
||||
|
||||
@ -321,6 +321,10 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
|
||||
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
|
||||
static constexpr int kSmemSize1colblock = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + kSmemPSize
|
||||
|
||||
Loading…
Reference in New Issue
Block a user