/* Copyright (c) 2022, Tri Dao.
 */

#pragma once

#include "fmha_fprop_kernel_1xN.h"
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int ROWS, int THREADS_PER_ROW, int M, typename Gmem_softmax_sum>
inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M],
                                Gmem_softmax_sum gmem_softmax_d, int tidx) {
    float sum[M];
    fmha::SumOp<float> sum_op;
    #pragma unroll
    for (int mi = 0; mi < M; ++mi) {
        sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(fmha::hmulsum8(do_[mi], o[mi]), sum_op);
    }
    const int dp_sum_row = tidx / THREADS_PER_ROW;
    if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
        gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params &params, Prng &ph,
                                                     const int loop_step_idx) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
    // The description of the CTA tile for the 3rd batched GEMM.
    using Cta_tile_dkv =
        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;

    static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
    static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
    static_assert(Cta_tile_dkv::K == 16);

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
    // The MMA tile for the 3rd GEMM.
    using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to reload Q transposed.
    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Col, Gmem_tile_q::BYTES_PER_LDG, 2>;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K^T. Treat K^T as V.
    using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;

    // Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong.
    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load dO.
    using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
    // The shared memory tile to load dO.
    // Treating dO as Q.
    using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
    // The shared memory tile to reload dO transposed.
    using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Col, Gmem_tile_do::BYTES_PER_LDG, 2>;

    // The global memory tile to load O. Loading O here is similar to loading dO.
    using Gmem_tile_o = Gmem_tile_do;

    // The global memory tile to store dQ.
    using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
    using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
    // The shared memory tile to swizzle dQ.
    using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;

    // The global memory tile to store dV.
    using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle dV.
    using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;

    // The global memory tile to store dK.
    using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle dK.
    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
    static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;

    using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;

    // using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
    using Gemm1 = Gemm_Q_K<Kernel_traits, true>;

    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;

    // Shared memory.
    extern __shared__ char smem_[];

    // Shared memory layout if we keep V in registers:
    // dO | Q | K / V | dQ | S | dP | dP_sum
    // dV | dK
    // Shared memory layout if we keep V in shared memory:
    // dO | Q | K | V | dQ | S | dP | dP_sum
    // dV | dK

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    // if( binfo.stop_early() ) return;
    if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;

    Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
    // Allocate the global memory tile loader for dQ.
    Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx);
    Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
    // Allocate the global memory tile loader for S.
    Gmem_tile_s gmem_s(params, binfo, tidx);

    fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
    // The base pointer of smem_v.
    char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];

    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);
    // Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
    Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for dO.
    Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
    // Allocate the shared memory tile loader for dO.
    Smem_tile_do smem_do(&smem_[0], tidx);
    Smem_tile_dot smem_dot(&smem_[0], tidx);
    // Allocate the shared memory tile loader for Q^T.
    // TODO: assert that this points to the same memory as gemm_q_k.smem_q
    Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);

    Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
    Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);

    Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
    Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);

    static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
    const int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
    // constexpr int steps = Cta_tile_p::N / Cta_tile_p::M;
    const int steps = params.s / Cta_tile_p::M - begin;

    // Wind gmem tiles to the correct position.
    gmem_q.move(begin);
    gmem_do.move(begin);
    gmem_o.move(begin);
    gmem_dq.move(begin);
    gmem_dq_tmp.move(begin);
    // TODO: need to move gmem_s if we want the intermediate result for debugging
    gmem_softmax_lse.move(begin);
    gmem_softmax_d.move(begin);

    if (!Is_first) {
        gmem_k.move(loop_step_idx);
        gmem_v.move(loop_step_idx);
    }

    // Trigger the loads for K.
    gmem_k.load();
    // Trigger the loads for Q.
    gmem_q.load();
    // Trigger the loads for V.
    gmem_v.load();
    // Trigger the loads for dO.
    gmem_do.load();
    // Trigger the loads for O.
    if (Is_first) { gmem_o.load(); }

    float p_lse[Mma_tile_p::MMAS_M * 2];
    gmem_softmax_lse.load(reinterpret_cast<uint32_t (&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
    gmem_softmax_lse.move();

    if (!Is_first) { __syncthreads(); }
    // Commit the data for Q, dO, and V to shared memory.
    gmem_q.commit(gemm_q_k.smem_q);
    gmem_do.commit(smem_do);
    if (Is_first) {
        dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
            gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
        );
    }

    // Instead of scaling dP by rp_dropout, we scale V instead.
    if (Is_dropout) {
        const uint32_t scale_dropout = params.scale_dropout;
        #pragma unroll
        for(int it=0; it < Gmem_tile_v::LDGS; it++){
            gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
        }
    }
    gmem_v.commit(smem_v);

    // const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t &>(params.scale_bmm1);
    // #pragma unroll
    // for(int it=0; it < Gmem_tile_k::LDGS; it++){
    //     gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
    // }

    // Commit the data for K to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_k.commit(gemm_q_k.smem_k);
    }

    __syncthreads();

    // Load the fragments for Q.
    gemm_q_k.load_q();

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
    if (Kernel_traits::V_IN_REGS) {
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            smem_v.load(frag_v[ki], ki);
        }
    }

    float dp_sum[Mma_tile_p::MMAS_M * 2];
    gmem_softmax_d.load(reinterpret_cast<uint32_t (&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
    gmem_softmax_d.move();

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();

        // Commit the data to shared memory for V.
        gmem_k.commit(gemm_q_k.smem_k);

        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for K.
    gemm_q_k.load_k();
    // Load the fragments for K^T.
    // typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
    // smem_kt.load(frag_kt[0], 0);
    // typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
    // #pragma unroll
    // for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
    //     smem_kt.load(frag_kt[ki], ki);
    // }

    // Create the object to do the softmax.
    // We won't be using the shared memory for this softmax at all.
    Softmax softmax(params, smem_, tidx);

    // Declare the accumulators for the 3rd gemm.
    fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
    fmha::Clear_accumulator<float, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
    fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
    fmha::Clear_accumulator<float, Cta_tile_dkv::WARPS_K>::apply(acc_dk);

    // Load over the entire sequence length.
    for( int l = 0; l < steps; l++ ) {
        const int loop = (begin + l) * Cta_tile_p::M;
        if( loop >= binfo.actual_seqlen ) break;

        // Load the fragments for V.
        // typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
        if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }

        // Load the fragments for dO.
        typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
        smem_do.load(frag_do[0], 0);

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<float, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        gemm_q_k(acc_p);

        // Load the mask for that iteration.
        mask.load(begin + l);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack_noscale(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        // Scale by log-sum-exp of the softmax.
        // softmax.apply_exp(p_lse);
        softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);

        if (Is_dropout) {
            // softmax.apply_dropout(ph, params.p_dropout_in_uint);
            // softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
            softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
        static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
        softmax.pack(frag_p);

        // Store s * dmask to smem for transpose.
        smem_s.store(frag_p);

        // Trigger the load for the next Q values.
        if( l < steps - 1) {
            gemm_q_k.smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load();
        }

        // if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
        //     // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
        //     __syncthreads();
        // }

        fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        #pragma unroll
        for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
            #pragma unroll
            for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
                #pragma unroll
                for (int ii = 0; ii < 8; ++ii) {
                    acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
                }
            }
        }

        // Do this part of dP^T = (dO * V^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of dO values.
            smem_do.load(frag_do[ki & 1], ki);
            if (!Kernel_traits::V_IN_REGS) {
                smem_v.load(frag_v[ki & 1], ki);
                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            } else {
                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
            }
            // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
            //     float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
            //     printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
            //     tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
            //     printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
            // }
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            if (!Kernel_traits::V_IN_REGS) {
                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
            } else {
                fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
            }
        }

        // Load the fragments for K^T.
        typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
        smem_kt.load(frag_kt[0], 0);

        // Trigger the load for the next dO values.
        if( l < steps - 1) {
            smem_do.move_to_next_write_buffer();
            gmem_do.move();
            gmem_do.load();
            if (Is_first) {
                gmem_o.move();
                gmem_o.load();
            }
        }

        softmax.unpack_noscale(acc_dp);
        // // TD [2022-04-01]: Don't need to apply mask since the corresponding value in softmax
        // // will be zero.
        // for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum[mi] *= params.p_dropout; }

        Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
        softmax.pack(frag_dp);

        if (!Is_dropout) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                    frag_p[mi][ni].hmul(frag_dp[mi][ni]);
                }
            }
        } else {
            __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2];
            for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
                dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]);
            }
            const __half zero_h = __half(0.f);
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                    #pragma unroll
                    for (int ii = 0; ii < 4; ++ii) {
                        const __half2 p = frag_p[mi][ni].template elt_as<__half2>(ii);
                        const __half2 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__half2>(ii));
                        // If this element is dropped, then frag_p stores -p instead of p.
                        // So pd holds -p * dp_sum in that case.
                        const __half2 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
                        const __half low = __low2half(p) >= zero_h ? __low2half(pdp) : __low2half(pd);
                        const __half high = __high2half(p) >= zero_h ? __high2half(pdp) : __high2half(pd);
                        frag_p[mi][ni].template elt_as<__half2>(ii) = __halves2half2(low, high);
                    }
                }
            }
        }

        // Store dp to smem for transpose.
        smem_dp.store(frag_p);

        // gmem_s.store(frag_p, mask);
        // gmem_s.move();

        // Declare the accumulators for the 2nd gemm.
        fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
        fmha::Clear_accumulator<float, Cta_tile_dq::WARPS_K>::apply(acc_dq);

        // Do this part of dQ = dP * K.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of K^T values.
            smem_kt.load(frag_kt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dq::MMAS_K;
            fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
            // fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
        }

        static_assert(Gmem_tile_dq::LOOPS == 1);

        // Swizzle the elements and do the final reduction.
        // Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
        // might happen after the smem_dq writes in this iteration.
        __syncthreads();
        smem_dq.store(acc_dq, 0);

        typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
        static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
        static_assert(Mma_tile_dkv::MMAS_K == 1);
        smem_dot.load(frag_dot[0], 0);

        // Threads in a warp are communicating via shared memory (smem_s and smem_dp).
        __syncwarp();
        typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
        smem_s.load(frag_s);

        if (Is_dropout) {
            #pragma unroll
            for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
                #pragma unroll
                for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
                    frag_s[ki][mi].hrelu_();
                }
            }
        }

        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of dO^T values.
            smem_dot.load(frag_dot[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dkv::MMAS_K;
            fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
        }
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0]));
        //     printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y);
        //     float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1]));
        //     printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y);
        // }
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
        //     printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
        // }

        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
        if(l < steps - 1) {
            gmem_q.commit(gemm_q_k.smem_q);
        }

        uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
        if (!Is_first) { gmem_dq_tmp.load(dq_out, 0); }

        // __syncthreads();
        // Commit the values for Q and dO into shared memory.
        if(l < steps - 1) {
            gmem_do.commit(smem_do);
            if (Is_first) {
                dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW>(
                    gmem_do.fetch_, gmem_o.fetch_, gmem_softmax_d, tidx
                );
            }
            gmem_softmax_lse.load(reinterpret_cast<uint32_t (&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
            gmem_softmax_lse.move();
        }

        typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
        smem_dp.load(frag_dpt);

        gemm_q_k.reload_k();

        typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
        static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
        static_assert(Mma_tile_dkv::MMAS_K == 1);
        smem_qt.load(frag_qt[0], 0);

        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q^T values.
            smem_qt.load(frag_qt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dkv::MMAS_K;
            fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Make sure dQ is in shared memory.
        __syncthreads();

        if (l < steps - 1) {
            gmem_softmax_d.load(reinterpret_cast<uint32_t (&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
            gmem_softmax_d.move();
        }

        // Load from shared memory.
        smem_dq.template load</*zero_init=*/Is_first>(dq_out);

        const bool is_final_write =
            Is_last
            || ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen)
            || ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
        if (is_final_write) {
            // if (Is_dropout) {
            //     dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
            // }
            for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
                dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
            }
            // Output the values.
            gmem_dq.store(dq_out, 0);
            // Move to the next part of the output.
            gmem_dq.move();
        } else {
            // Output the values.
            gmem_dq_tmp.store(dq_out, 0);
        }

        // Move to the next part of the output.
        if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }

        // // Make sure the data is in shared memory.
        // __syncthreads();

        // Commit the values for Q and dO into shared memory.
        if(l < steps - 1) {
            gemm_q_k.smem_q.move_to_next_read_buffer();
            gemm_q_k.reload_q();
            smem_qt.move_to_next_read_buffer();
            // smem_qt.load(frag_qt[0], 0);
            smem_do.move_to_next_read_buffer();
            smem_dot.move_to_next_read_buffer();
            // smem_dot.load(frag_dot[0], 0);
        }

    }  // Outer loop over the sequence length.
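
    // Epilogue for dV and dK: the loop above has finished accumulating both tiles in fp32.
    // Below we apply the remaining scales (rp_dropout for dV when dropout is enabled, and the
    // softmax scale scale_bmm1f for dK), swizzle the accumulators through shared memory, and
    // store the results to global memory.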
    if (Is_dropout) {
        for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
            for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
                acc_dv[mi][ni].mul_(params.rp_dropout);
            }
        }
    }
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
    //     printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
    // }
    for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
        for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
            // acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
            acc_dk[mi][ni].mul_(params.scale_bmm1f);
        }
    }
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //     printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
    // }

    __syncthreads();
    // TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
    // the total amount of shared mem?
    // Epilogue swizzle for dV.
    Smem_tile_dv smem_dv(&smem_[0], tidx);
    smem_dv.store(acc_dv);

    // Epilogue swizzle for dK.
    Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
    smem_dk.store(acc_dk);

    __syncthreads();
    uint4 dv_out[Smem_tile_dv::NUM_LDS];
    smem_dv.load(dv_out);
    Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2,
                         params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx);
    if (!Is_first) { gmem_dv.move(loop_step_idx); }
    gmem_dv.store(dv_out);

    uint4 dk_out[Smem_tile_dk::NUM_LDS];
    smem_dk.load(dk_out);
    // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) {
    //     dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f);
    // }
    Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2,
                         params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx);
    if (!Is_first) { gmem_dk.move(loop_step_idx); }
    gmem_dk.store(dk_out);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// loop_steps = -1 means the number of steps will be params.s / Kernel_traits::Cta_tile_p::N.
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
inline __device__ void compute_dq_dk_dv_1xN(const Params &params) {
    constexpr int N_per_loop = Kernel_traits::Cta_tile_p::N;

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const int tidx_global = (bidb * params.h + bidh) * blockDim.x + tidx;
    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), tidx_global, std::get<1>(seeds));

    if (loop_steps == 1) {
        compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
    } else if (loop_steps == 2) {
        compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
        compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
    } else {
        if (params.s == N_per_loop) {
            compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
        } else {
            const int max_loop_steps = (params.s + N_per_loop - 1) / N_per_loop;
            compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
            for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
                compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
            }
            compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha