Remove unused sdPsum in dot_do_o function

parent b28ec236df
commit 5953c4f58c
@@ -73,11 +73,9 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <int THREADS_PER_ROW, typename Engine0, typename Layout0,
-          typename Engine1, typename Layout1, typename Engine2, typename Layout2>
+template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
-                                Tensor<Engine1, Layout1> &dP_sum, Tensor<Engine2, Layout2> &sdPsum,
-                                const int gdP_col_stride, const float scale) {
+                                Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
     static_assert(Layout0::rank == 3, "Only support 3D Tensor");
     static_assert(Layout1::rank == 1, "Only support 1D Tensor");
     CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
@@ -100,7 +98,6 @@ inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engi
         dP_sum_cur = flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
         if (threadIdx.x % THREADS_PER_ROW == 0) {
             dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
-            // recast<float>(sdPsum)(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum;
         }
     }
 }
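Note: for orientation, here is a minimal standalone CUDA sketch of the reduction that dot_do_o performs after this cleanup: each group of THREADS_PER_ROW threads reduces one row of dO * O into dP_sum, pre-scaled by `scale`. The indexing and the shuffle-based allreduce are simplifications of my own (the real code goes through CUTE tensors and flash::Allreduce), and they assume THREADS_PER_ROW is a power of two no larger than a warp.

    #include <cuda_runtime.h>

    // Sketch only: blockDim.x is assumed to be a multiple of both
    // THREADS_PER_ROW and warpSize, and one block covers
    // blockDim.x / THREADS_PER_ROW rows.
    template <int THREADS_PER_ROW>
    __global__ void dot_do_o_sketch(const float *dO, const float *O,
                                    float *dP_sum, int cols, float scale) {
        const int row  = threadIdx.x / THREADS_PER_ROW;
        const int lane = threadIdx.x % THREADS_PER_ROW;
        // Each thread accumulates a strided slice of its row.
        float acc = 0.f;
        for (int k = lane; k < cols; k += THREADS_PER_ROW) {
            acc += dO[row * cols + k] * O[row * cols + k];
        }
        // Butterfly allreduce across the THREADS_PER_ROW threads of one row.
        #pragma unroll
        for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
            acc += __shfl_xor_sync(0xffffffff, acc, offset);
        }
        // One writer per row, mirroring the threadIdx.x % THREADS_PER_ROW == 0
        // guard in the hunk above; the result is already scaled.
        if (lane == 0) { dP_sum[row] = acc * scale; }
    }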
@@ -178,7 +175,7 @@ inline __device__ void compute_dot_do_o(const Params &params) {
     // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
     // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
     // so that (dP - dP_sum) is on the same scale.
-    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum, dP_sum,
+    dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
                                                 Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     if (Clear_dQaccum) {
         Tensor zero = make_fragment_like(tdQgdQaccum);
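Note: as I read the scaling comment in this hunk (an interpretation of mine, not text from the commit): params.p_dropout here is the keep probability p, the forward output O already carries the 1/p dropout rescale, and dP deliberately does not. Passing scale = p therefore cancels the 1/p baked into O, so that in the softmax backward

    dP_sum_i = p * sum_k dO_ik * O_ik,    dS_ij = P_ij * (dP_ij - dP_sum_i)

both terms of (dP - dP_sum) sit on the same scale, and the single 1/p rescale is applied just once, to the final dQ and dK.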
@@ -517,8 +514,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
     // sP and sdQ share the same memory so be careful
     Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
-    Tensor sdPsum = make_tensor(make_smem_ptr(reinterpret_cast<float2 *>((sP.data() + cute::max(size(sP), size(sdQ))).get())),
-                                Shape<Int<Kernel_traits::kSmemdPsumCount / 2>>{});
 
     typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
     auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
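Note: the hunk above is the "sP and sdQ share the same memory" idiom with the now-dead sdPsum carve-out removed. A minimal CUDA sketch of that aliasing pattern (simplified raw-pointer views of my own, not the CUTE tensors used in the repo):

    #include <cuda_runtime.h>

    // Launch with at most 256 threads per block for this sketch.
    __global__ void alias_smem_sketch() {
        __shared__ float smem[256];
        float *sP  = smem;  // view 1: attention-scores tile
        float *sdQ = smem;  // view 2: dQ tile, over the same bytes
        sP[threadIdx.x] = 1.f;                 // phase 1 writes/reads sP ...
        __syncthreads();                       // ... and must finish before
        sdQ[threadIdx.x] = 2.f * threadIdx.x;  // phase 2 reuses the bytes as sdQ
    }

The two views are only valid in disjoint phases of the kernel, which is exactly why the original comment says "be careful".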
@@ -733,7 +728,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
     // if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
     if (Is_first) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
 
@@ -930,11 +925,6 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
         #pragma unroll
         for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
         gdPsum.data() = gdPsum.data() + (-int(kBlockM));
-        // if (!Is_first && tidx < kBlockM / 2) {
-        //     sdPsum(tidx) = recast<float2>(gdPsum)(tidx);
-        // if (!Is_first && tidx < kBlockM) {
-        //     recast<float>(sdPsum)(tidx) = gdPsum(tidx);
-        // }
     }
 
     if (!Is_last) {
@@ -976,7 +966,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
 
     if (Is_first && m_block > m_block_min) {
         cute::copy(tdOrdO, tdOsdO);
-        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum, sdPsum,
+        dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
                                                     Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
     }
 
@@ -1317,7 +1307,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params &params, const in
     Tensor dP_sum = make_fragment_like(lse);
     cute::copy(tdOrdO, tdOsdO);
     dot_do_o<Kernel_traits::kGmemThreadsPerRow>(
-        tdOrdO, tdOrO, sdPsum, sdPsum,
+        tdOrdO, tdOrO, sdPsum,
         Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout
     );
     __syncthreads();
@@ -321,6 +321,10 @@ struct Flash_bwd_kernel_traits : public Base {
     static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
     static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
     static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
+    static constexpr int kSmemSize = kSmemQdOSize
+        + (!Is_V_in_regs
+               ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
+               : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
     static constexpr int kSmemSize1colblock = kSmemQdOSize
         + (!Is_V_in_regs
             ? kSmemKVSize + kSmemdSSize + kSmemPSize
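Note: to make the restored kSmemSize arithmetic concrete, here is a hedged host-side sketch with made-up byte counts (the real values come from the kernel traits; Is_V_in_regs is the same trait flag as above). As the kSmemKVSize / 2 term suggests, when V is held in registers only half of the K/V tile stays resident in shared memory, so that half plus dS and max(P, dQ) competes with the full K/V footprint; the outer std::max models those two live phases.

    #include <algorithm>
    #include <cstdio>

    // Illustrative sizes only, not the repo's actual tile sizes.
    constexpr int kSmemQdOSize = 16384, kSmemKVSize = 16384;
    constexpr int kSmemdSSize  = 4096,  kSmemPSize  = 4096, kSmemdQSize = 8192;

    template <bool Is_V_in_regs>
    constexpr int smem_size() {
        return kSmemQdOSize
            + (!Is_V_in_regs
                   ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
                   : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
    }

    int main() {
        std::printf("V in smem: %d bytes\n", smem_size<false>());  // 16384 + 16384 + 4096 + 8192 = 45056
        std::printf("V in regs: %d bytes\n", smem_size<true>());   // 16384 + max(16384, 8192 + 4096 + 8192) = 36864
        return 0;
    }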