Don't use Smem_dp_sum in backward pass
To reduce smem usage for SM75
parent b17c6fe235
commit d380e87fb6
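The backward kernel previously staged the per-row sums of dO ∘ O (dp_sum) in a dedicated shared-memory tile, Smem_dp_sum. Turing (SM75) caps opt-in dynamic shared memory at 64 KB per block, far below the roughly 163 KB available on SM80, so the patch drops that tile and keeps the reduction in registers and global memory instead. A minimal sketch, not part of the patch, of how the available budget can be checked against a kernel's request:

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative helper: query the opt-in dynamic shared-memory ceiling of the
// current device and compare it with a requested allocation.
bool fits_in_dynamic_smem(int requested_bytes) {
    int dev = 0, max_optin = 0;
    cudaGetDevice(&dev);
    cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
    printf("requested %d bytes, device allows up to %d bytes\n", requested_bytes, max_optin);
    return requested_bytes <= max_optin;
}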
@@ -15,15 +15,13 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
     constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
     constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
     constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
-    constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
 
     using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
     constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
     static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
     static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
-    static_assert(smem_size_dp_sum == 16 * 4 * 2);
 
-    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum;
+    constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
 
     bool is_dropout = params.p_dropout < 1.f;  // params.p_dropout is the probability of "keeping"
     bool is_causal = params.is_causal;
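The static_asserts above pin down the sizes involved: Smem_dp_sum::BYTES_PER_TILE is asserted to be 16 * 4 * 2 = 128 bytes, so dropping it from smem_size_dq_dk_dv trims the dynamic allocation by exactly that amount, which only matters when the total sits right at a device ceiling. A toy recomputation of the before/after totals, with every size other than the dp_sum tile a made-up placeholder rather than the real Kernel_traits value:

// Toy budget arithmetic for the launcher change above. Only smem_size_dp_sum
// comes from the static_assert; the other sizes are placeholders.
constexpr int smem_size_q      = 16 * 1024;   // placeholder
constexpr int smem_size_v      = 16 * 1024;   // placeholder
constexpr int smem_size_dq     = 16 * 1024;   // placeholder
constexpr int smem_size_s      = 4 * 1024;    // placeholder
constexpr int smem_size_dp_sum = 16 * 4 * 2;  // 128 bytes, per the static_assert

constexpr int old_total = smem_size_q * 2 + smem_size_v * 2 + smem_size_dq
                        + smem_size_s * 2 + smem_size_dp_sum;
constexpr int new_total = smem_size_q * 2 + smem_size_v * 2 + smem_size_dq
                        + smem_size_s * 2;
static_assert(old_total - new_total == 128, "the whole saving is the dp_sum tile");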
@@ -41,6 +39,7 @@ void run_fmha_dgrad_fp16_sm80_loop_(const Fused_multihead_attention_fprop_params
             : (is_causal ? &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, true, /*loop_steps=*/2> : &fmha_dgrad_fp16_sm80_dq_dk_dv_loop_kernel<Kernel_traits, false, false, /*loop_steps=*/2>);
     }
 
+    // printf("N = %d, WARPS_N = %d, Smem size = %d\n", N, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
     if( smem_size_dq_dk_dv >= 48 * 1024 ) {
         FMHA_CHECK_CUDA(cudaFuncSetAttribute(
             kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
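Anything above the default 48 KB of dynamic shared memory has to be opted into per kernel, which is what the FMHA_CHECK_CUDA(cudaFuncSetAttribute(...)) call above does before launch. A stripped-down version of that launch sequence, using a stand-in kernel rather than the real fmha kernel:

#include <cuda_runtime.h>

// Stand-in kernel: just touches the dynamically sized shared buffer.
// Assumes smem_bytes >= blockDim.x * sizeof(float).
__global__ void dummy_kernel(float *out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) { out[blockIdx.x] = smem[0]; }
}

void launch_with_large_smem(float *out, int smem_bytes, cudaStream_t stream) {
    if (smem_bytes >= 48 * 1024) {
        // Above 48 KB the launch fails unless the attribute is raised first.
        cudaFuncSetAttribute(dummy_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
    }
    dummy_kernel<<<1, 128, smem_bytes, stream>>>(out);
}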
@@ -97,4 +96,28 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params &params
         using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u>;
         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
     }
+    // if (params.d == 64) {
+    //     auto dprops = at::cuda::getCurrentDeviceProperties();
+    //     if (dprops->major == 7 && dprops->minor == 5) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //         run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //     } else {
+    //         if( params.s == 128 ) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u>;
+    //             run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //         } else if( params.s >= 256 ) {
+    //             if (dprops->major == 8 && dprops->minor == 0) {
+    //                 // Don't share smem for K & V, and don't keep V in registers
+    //                 // This speeds things up by 2-3% by avoiding register spills, but it
+    //                 // uses more shared memory, which is fine on A100 but not other GPUs.
+    //                 // For other GPUs, we keep V in registers.
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //             } else if (dprops->major == 8 && dprops->minor > 0) {
+    //                 using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u>;
+    //                 run_fmha_dgrad_fp16_sm80_loop_<Kernel_traits>(params, stream);
+    //             }
+    //         }
+    //     }
+    // }
 }
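The commented-out block added above sketches per-architecture dispatch: spend extra shared memory on A100, where it buys a 2-3% speedup by avoiding register spills, but keep V in registers (and the smaller-footprint traits) on GPUs with less shared memory such as SM75. A bare-bones version of that dispatch using the plain CUDA runtime instead of at::cuda::getCurrentDeviceProperties; the branch bodies are placeholders, not the real trait selection:

#include <cuda_runtime.h>

// Illustrative arch dispatch in the spirit of the commented-out block above.
void dispatch_by_arch() {
    int dev = 0;
    cudaGetDevice(&dev);
    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, dev);
    if (props.major == 8 && props.minor == 0) {
        // A100: plenty of shared memory, keep V in smem and skip K/V sharing.
    } else if (props.major == 7 && props.minor == 5) {
        // Turing: 64 KB per-block ceiling, prefer the smaller-footprint traits.
    } else {
        // Other SM8x parts sit in between (e.g. 99 KB per block on SM86).
    }
}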
@@ -12,16 +12,19 @@ namespace fmha {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template <typename Smem_dp_sum, int M>
-inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M],
-                                Smem_dp_sum smem, const int buffer_idx) {
+template <int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
+inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
+                                Gmem_softmax_sum gmem_softmax_d, int tidx) {
+    float sum[M];
+    fmha::SumOp<float> sum_op;
     #pragma unroll
     for (int mi = 0; mi < M; ++mi) {
-        sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi]));
+        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
     }
-    static_assert(M == 1);
-    smem.store(sum[0], buffer_idx);
-    // smem.store(sum, buffer_idx);
+    const int dp_sum_row = tidx / THREADS_PER_ROW;
+    if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
+        gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
+    }
 }
 
////////////////////////////////////////////////////////////////////////////////////////////////////
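The rewritten dot_do_o no longer round-trips the partial sums through shared memory: each group of THREADS_PER_ROW threads reduces hmulsum8(dO, O) with fmha::Allreduce, and only the lane at the head of each row writes the result straight to gmem_softmax_d. A self-contained sketch of that pattern, assuming THREADS_PER_ROW is a power of two no larger than a warp and that the reduction is shuffle-based; the kernel below is illustrative, not the library's code:

#include <cuda_runtime.h>

// Butterfly allreduce over the THREADS_PER_ROW lanes of a row, followed by a
// single store per row to global memory. Assumes THREADS_PER_ROW is a power of
// two <= 32 and blockDim.x is a multiple of 32.
template <int ROWS, int THREADS_PER_ROW>
__global__ void row_dot_to_gmem(const float *a, const float *b, float *row_sum) {
    const int tidx = threadIdx.x;
    const int row = tidx / THREADS_PER_ROW;
    const int col = tidx % THREADS_PER_ROW;
    float x = (row < ROWS) ? a[row * THREADS_PER_ROW + col] * b[row * THREADS_PER_ROW + col]
                           : 0.f;
    #pragma unroll
    for (int offset = THREADS_PER_ROW / 2; offset > 0; offset /= 2) {
        // XOR shuffles stay within each aligned group of THREADS_PER_ROW lanes.
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    // Mirrors the "(dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)" gate above.
    if (row < ROWS && col == 0) { row_sum[row] = x; }
}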
@@ -101,8 +104,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
 
     using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
 
-    using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum;
-
     // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
     using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false>;
 
@@ -208,26 +209,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
     gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
     gmem_softmax_lse.move();
 
-    float dp_sum[Mma_tile_p::MMAS_M * 2];
-    if (!Is_first) {
-        gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-        gmem_softmax_d.move();
-    }
-
-    float dp_sum_regs[Gmem_tile_do::LDGS];
-    Smem_dp_sum smem_dp_sum(reinterpret_cast<float *>(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE * 2]), tidx);
-
     if (!Is_first) { __syncthreads(); }
     // Commit the data for Q, dO, and V to shared memory.
     gmem_q.commit(gemm_q_k.smem_q);
     gmem_do.commit(smem_do);
     if (Is_first) {
-        dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0);
-        const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW;
-        if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-            gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row);
-        }
-        gmem_softmax_d.move();
+        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+            gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+        );
     }
 
     // Instead of scaling dP by rp_dropout, we scale V instead
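The context line above also documents a trick this patch leaves in place: rather than multiplying every element of dP by rp_dropout, the kernel scales V once, which gives the same GEMM result. A toy check of that identity, plain C++ with made-up values:

#include <cstdio>

// Scaling V by rp_dropout up front gives the same dot product as scaling dP
// afterwards (rp_dropout = 1 / p_keep; p_keep is assumed to be 0.9 here).
int main() {
    const float dP[3] = {0.2f, -1.5f, 3.0f};
    const float V[3] = {1.0f, 0.5f, -2.0f};
    const float rp_dropout = 1.0f / 0.9f;

    float scale_dp = 0.f, scale_v = 0.f;
    for (int k = 0; k < 3; ++k) {
        scale_dp += (dP[k] * rp_dropout) * V[k];  // scale dP
        scale_v  += dP[k] * (rp_dropout * V[k]);  // scale V instead
    }
    printf("%f %f\n", scale_dp, scale_v);  // equal up to rounding
    return 0;
}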
@@ -266,6 +255,10 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         }
     }
 
+    float dp_sum[Mma_tile_p::MMAS_M * 2];
+    gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
+    gmem_softmax_d.move();
+
     // Commit the data for V to shared memory if it has not been done already.
     if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
         // Make sure we are done loading the fragments for K.
@@ -357,21 +350,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // __syncthreads();
         // }
 
-        // TD [2022-04-24]: if Is_first, then it's faster to set acc_dp to zero then subtract by
-        // dp_sum later. If !Is_first, then it's faster to set acc_dp to -dp_sum and don't subtract
-        // later. This is because loading dp_sum earlier uses more registers.
         fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
-        if (Is_first) {
-            fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_dp);
-        } else {
-            #pragma unroll
-            for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
-                #pragma unroll
-                for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
-                    #pragma unroll
-                    for (int ii = 0; ii < 8; ++ii) {
-                        acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
-                    }
-                }
-            }
+        #pragma unroll
+        for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
+            #pragma unroll
+            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
+                #pragma unroll
+                for (int ii = 0; ii < 8; ++ii) {
+                    acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
+                }
+            }
         }
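With dp_sum now always read back from gmem_softmax_d, the Is_first special case disappears: every iteration seeds the dP accumulator with -dp_sum, so the per-row subtraction is folded into the MMA accumulation instead of being applied afterwards through softmax.subtract_dp_sum (removed in a later hunk). A toy demonstration that seeding the accumulator is equivalent, plain C++ with made-up values:

#include <cstdio>

// Accumulating a dot product into an accumulator that starts at -dp_sum equals
// computing the dot product from zero and subtracting dp_sum afterwards.
int main() {
    const float dO[4] = {1.f, 2.f, 3.f, 4.f};
    const float v[4] = {0.5f, -1.f, 2.f, 0.25f};
    const float dp_sum = 3.f;  // stand-in for the row sum of dO * O

    float acc_zero = 0.f;        // old path: start at zero, subtract later
    float acc_seeded = -dp_sum;  // new path: seed with -dp_sum
    for (int k = 0; k < 4; ++k) {
        acc_zero += dO[k] * v[k];
        acc_seeded += dO[k] * v[k];
    }
    printf("%f %f\n", acc_zero - dp_sum, acc_seeded);  // prints the same value
    return 0;
}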
@@ -409,12 +395,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
         smem_kt.load(frag_kt[0], 0);
 
-        if (Is_first) {
-            const int quad = (tidx % Cta_tile_p::THREADS_PER_WARP) / 4;
-            const int row[2] = {quad, quad + 8};
-            smem_dp_sum.load(dp_sum, row, l % 2);
-        }
-
         // Trigger the load for the next dO values.
         if( l < steps - 1) {
             smem_do.move_to_next_write_buffer();
@@ -430,7 +410,6 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
         // // will be zero.
         // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }
-        if (Is_first) { softmax.subtract_dp_sum(dp_sum); }
 
         Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
         softmax.pack(frag_dp);
@@ -547,21 +526,12 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         if(l < steps - 1) {
             gmem_do.commit(smem_do);
             if (Is_first) {
-                // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum);
-                // smem_dp_sum.move_to_next_write_buffer();
-                dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2);
-                const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW;
-                if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) {
-                    gmem_softmax_d.store_row(reinterpret_cast<uint32_t(&)[Gmem_tile_do::LDGS]>(dp_sum_regs), dp_sum_row_1);
-                }
-                gmem_softmax_d.move();
+                dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
+                    gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
+                );
             }
             gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
             gmem_softmax_lse.move();
-            if (!Is_first) {
-                gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
-                gmem_softmax_d.move();
-            }
         }
 
         typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
@@ -591,6 +561,11 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng
         // Make sure dQ is in shared memory.
         __syncthreads();
 
+        if (l < steps - 1) {
+            gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
+            gmem_softmax_d.move();
+        }
+
         // Load from shared memory.
         smem_dq.template load</*zero_init=*/Is_first>(dq_out);
 
@@ -120,10 +120,25 @@ void run_fmha_fp16_sm80(Launch_params<Fused_multihead_attention_fprop_params> &l
         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     }
     // if (launch_params.params.d == 64) {
-    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
-    //     using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
-    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     // using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<64, 64, 16, 1, 4, 0x08u>;
+    //     // using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
+    //     using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //     run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
     // }
+    // if (launch_params.params.d == 64) {
+    //     if( launch_params.params.s == 128 ) {
+    //         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //         run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //     } else if( launch_params.params.s >= 256 ) {
+    //         auto dprops = at::cuda::getCurrentDeviceProperties();
+    //         if (dprops->major == 8 && dprops->minor >= 0) {
+    //             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         } else if (dprops->major == 7 && dprops->minor == 5) {
+    //             using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
+    //             run_fmha_fp16_sm80_loop_<Kernel_traits>(launch_params, configure);
+    //         }
+    //     }
+    // }
 }