Don't use Smem_dp_sum in backward pass

To reduce shared memory (smem) usage on SM75, which has a smaller per-block smem limit than SM80
Tri Dao 2022-06-04 16:01:36 -07:00
parent b17c6fe235
commit d380e87fb6
3 changed files with 76 additions and 63 deletions

View File

@@ -15,15 +15,13 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
static_assert(smem_size_dp_sum == 16 * 4 * 2);
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
bool is_causal = params.is_causal;
@@ -41,6 +39,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
: (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
}
// printf("N = %d, WARPS_N = %d, Smem size = %d\n", N, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -97,4 +96,28 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &para
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
}
// if (params.d == 64) {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 7 && dprops->minor == 5) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else {
// if( params.s == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else if( params.s >= 256 ) {
// if (dprops->major == 8 && dprops->minor == 0) {
// // Don't share smem for K & V, and don't keep V in registers
// // This speeds things up by 2-3% by avoiding register spills, but it
// // uses more shared memory, which is fine on A100 but not other GPUs.
// // For other GPUs, we keep V in registers.
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// } else if (dprops->major == 8 && dprops->minor > 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
// run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
// }
// }
// }
// }
}
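
The motivation for dropping smem_size_dp_sum from the budget above is that SM75 (Turing) caps opt-in dynamic shared memory per block at roughly 64 KB, while SM80 (A100) allows up to about 163 KB, so every tile counted into smem_size_dq_dk_dv matters on Turing. Below is a minimal host-side sketch of that kind of budget check using standard CUDA attributes; the helper name fits_in_dynamic_smem and its standalone form are illustrative, the repo itself only performs the cudaFuncSetAttribute call shown in the hunk above.

#include <cuda_runtime.h>

// Illustrative sketch (not part of this commit): check a kernel's dynamic
// smem request against the current device's opt-in limit before launching.
static bool fits_in_dynamic_smem(const void *kernel, int smem_bytes) {
    int device = 0;
    cudaGetDevice(&device);
    int max_optin = 0;
    // Largest dynamic smem a block may opt into: ~64 KB on SM75, ~163 KB on SM80.
    cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (smem_bytes > max_optin) {
        return false;  // e.g. a budget that still includes Smem_dp_sum on SM75
    }
    if (smem_bytes >= 48 * 1024) {
        // Above the default 48 KB the limit must be raised explicitly,
        // mirroring the cudaFuncSetAttribute call in the diff above.
        return cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    smem_bytes) == cudaSuccess;
    }
    return true;
}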

View File

@@ -12,16 +12,19 @@ namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Smem_dp_sum, int M>
inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
Smem_dp_sum smem, const int buffer_idx) {
template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
Gmem_softmax_sum gmem_softmax_d, int tidx) {
float sum[M];
fmha::SumOp<float> sum_op;
#pragma unroll
for (int mi = 0; mi < M; ++mi) {
sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
}
const int dp_sum_row = tidx / THREADS_PER_ROW;
if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
}
static_assert(M == 1);
smem.store(sum[0], buffer_idx);
// smem.store(sum, buffer_idx);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
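
The rewritten dot_do_o above swaps the Smem_dp_sum shared-memory reduction for fmha::Allreduce<THREADS_PER_ROW> with a SumOp, i.e. a register-level reduction across the threads that own one row. A minimal sketch of the shuffle-based butterfly pattern such an allreduce is typically built on (an illustration under the assumption that THREADS_PER_ROW is a power of two no larger than a warp, not the repo's actual implementation):

// Sketch of a butterfly allreduce over THREADS_PER_ROW consecutive lanes.
// Assumes the full warp is active. After the loop, every participating
// lane holds the sum of its group.
template <int THREADS_PER_ROW>
__device__ inline float allreduce_sum(float x) {
    #pragma unroll
    for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    return x;
}

Because the reduced value then lives in registers on every participating lane, the Smem_dp_sum tile and its double buffering are no longer needed, which is where the shared-memory saving in the first file comes from.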
@@ -101,8 +104,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
@@ -208,26 +209,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_lse.move();
float dp_sum[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
}
float dp_sum_regs[Gmem_tile_do::LDGS];
Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
if (!Is_first) { __syncthreads(); }
// Commit the data for Q, dO, and V to shared memory.
gmem_q.commit(gemm_q_k.smem_q);
gmem_do.commit(smem_do);
if (Is_first) {
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
}
gmem_softmax_d.move();
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
);
}
// Instead of scaling dP by rp_dropout, we scale V instead
@@ -266,6 +255,10 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
}
}
float dp_sum[Mma_tile_p::MMAS_M * 2];
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
// Commit the data for V to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
@@ -357,21 +350,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// __syncthreads();
// }
// TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
// dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
// later. This is because loading dp_sum earlier uses more registers.
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
if (Is_first) {
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
} else {
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
#pragma unroll
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
#pragma unroll
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
#pragma unroll
for (int ii = 0; ii < 8; ++ii) {
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
}
for (int ii = 0; ii < 8; ++ii) {
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
}
}
}
@@ -409,12 +395,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
smem_kt.load(frag_kt[0], 0);
if (Is_first) {
const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
const int row[2] = {quad, quad + 8};
smem_dp_sum.load(dp_sum, row, l % 2);
}
// Trigger the load for the next dO values.
if( l < steps - 1) {
smem_do.move_to_next_write_buffer();
@@ -430,7 +410,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
// // will be zero.
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
if (Is_first) { softmax.subtract_dp_sum(dp_sum); }
Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
softmax.pack(frag_dp);
@@ -547,21 +526,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
if(l < steps - 1) {
gmem_do.commit(smem_do);
if (Is_first) {
// dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
// smem_dp_sum.move_to_next_write_buffer();
dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
}
gmem_softmax_d.move();
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
);
}
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
gmem_softmax_lse.move();
if (!Is_first) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
}
}
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
@@ -591,6 +561,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
// Make sure dQ is in shared memory.
__syncthreads();
if (l < steps - 1) {
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
gmem_softmax_d.move();
}
// Load from shared memory.
smem_dq.template load</*zero_init=*/Is_first>(dq_out);
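
For reference, the value that dot_do_o writes to gmem_softmax_d, and that later iterations read back into dp_sum to initialize acc_dp to -dp_sum, is the per-row sum of the elementwise product of dO and O. A plain host-side reference for that reduction, useful as a correctness check (function and variable names here are illustrative, not from the repo):

#include <cstddef>
#include <vector>

// Reference computation of dp_sum: dp_sum[r] = sum over c of dO[r][c] * O[r][c].
// The kernel computes the same quantity with packed half2 math (hmulsum8)
// and the per-row allreduce sketched earlier.
std::vector<float> dp_sum_reference(const std::vector<float> &dO,
                                    const std::vector<float> &O,
                                    std::size_t rows, std::size_t cols) {
    std::vector<float> dp_sum(rows, 0.f);
    for (std::size_t r = 0; r < rows; ++r) {
        for (std::size_t c = 0; c < cols; ++c) {
            dp_sum[r] += dO[r * cols + c] * O[r * cols + c];
        }
    }
    return dp_sum;
}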

View File

@@ -120,10 +120,25 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
}
// if (launch_params.params.d == 64) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
// // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
// if (launch_params.params.d == 64) {
// if( launch_params.params.s == 128 ) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else if( launch_params.params.s >= 256 ) {
// auto dprops = at::cuda::getCurrentDeviceProperties();
// if (dprops->major == 8 && dprops->minor >= 0) {
// using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// } else if (dprops->major == 7 && dprops->minor == 5) {
// using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
// run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
// }
// }
// }
}