From 5d07483bbcce7ee727952a8ea8425aaaecd5a451 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 12 Jun 2022 16:37:13 -0700 Subject: [PATCH] Refactor Gmem code to store q, k, v pointers separately --- csrc/flash_attn/fmha_api.cpp | 16 +- csrc/flash_attn/src/fmha.h | 21 ++- csrc/flash_attn/src/fmha/gmem_tile.h | 163 ++++-------------- csrc/flash_attn/src/fmha/kernel_traits.h | 4 +- .../src/fmha_block_dgrad_kernel_1xN_loop.h | 29 ++-- .../src/fmha_block_fprop_kernel_1xN.h | 14 +- .../src/fmha_dgrad_kernel_1xN_loop.h | 29 ++-- csrc/flash_attn/src/fmha_fprop_kernel_1xN.h | 10 +- 8 files changed, 93 insertions(+), 193 deletions(-) diff --git a/csrc/flash_attn/fmha_api.cpp b/csrc/flash_attn/fmha_api.cpp index b92b7f2..1621d67 100644 --- a/csrc/flash_attn/fmha_api.cpp +++ b/csrc/flash_attn/fmha_api.cpp @@ -56,12 +56,18 @@ void set_params(Fused_multihead_attention_fprop_params ¶ms, memset(¶ms, 0, sizeof(params)); // Set the pointers and strides. - params.qkv_ptr = qkv_packed_d; - params.qkv_stride_in_elts = h * 3 * d; - params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type); + params.q_ptr = qkv_packed_d; + params.k_ptr = qkv_packed_d + get_size_in_bytes(h * d, data_type); + params.v_ptr = qkv_packed_d + 2 * get_size_in_bytes(h * d, data_type); + params.q_row_stride_in_elts = 3 * h * d; + params.k_row_stride_in_elts = 3 * h * d; + params.v_row_stride_in_elts = 3 * h * d; + params.q_head_stride_in_elts = d; + params.k_head_stride_in_elts = d; + params.v_head_stride_in_elts = d; params.o_ptr = o_packed_d; - params.o_stride_in_elts = h * d; - params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type); + params.o_row_stride_in_elts = h * d; + params.o_head_stride_in_elts = d; params.do_ptr = do_packed_d; params.o_tmp_ptr = o_tmp_d; diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h index eabacfa..35b8124 100644 --- a/csrc/flash_attn/src/fmha.h +++ b/csrc/flash_attn/src/fmha.h @@ -50,15 +50,21 @@ constexpr int D_DIM = 3; struct Qkv_params { // The QKV matrices. - void * __restrict__ qkv_ptr; + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; // The stride between rows of the Q, K and V matrices. // size_t qkv_stride_in_elts; // size_t qkv_stride_in_bytes; // TD [2022-04-16]: We're using 32-bit indexing to save registers. // The code probably won't work for arrays larger than 2GB. - uint32_t qkv_stride_in_elts; - uint32_t qkv_stride_in_bytes; + uint32_t q_row_stride_in_elts; + uint32_t k_row_stride_in_elts; + uint32_t v_row_stride_in_elts; + uint32_t q_head_stride_in_elts; + uint32_t k_head_stride_in_elts; + uint32_t v_head_stride_in_elts; // The number of heads. int h; @@ -71,17 +77,14 @@ struct Fused_multihead_attention_fprop_params : public Qkv_params { // The dQKV matrices. void * __restrict__ dqkv_ptr; - // Temporary for dKV. - void * __restrict__ dkv_ptr; - // The O matrix (output). void * __restrict__ o_ptr; // The stride between rows of O. // size_t o_stride_in_elts; // size_t o_stride_in_bytes; - uint32_t o_stride_in_elts; - uint32_t o_stride_in_bytes; + uint32_t o_row_stride_in_elts; + uint32_t o_head_stride_in_elts; // The pointer to the O_tmp matrix, which holds O intermediate value during // the loop; @@ -171,4 +174,4 @@ void run_fmha_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params ¶ void run_fmha_block_fp16_sm80(Launch_params &launch_params, const bool configure); -void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); \ No newline at end of file +void run_fmha_block_dgrad_fp16_sm80(const Fused_multihead_attention_fprop_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/fmha/gmem_tile.h b/csrc/flash_attn/src/fmha/gmem_tile.h index 7d47a21..e5bcac7 100644 --- a/csrc/flash_attn/src/fmha/gmem_tile.h +++ b/csrc/flash_attn/src/fmha/gmem_tile.h @@ -39,14 +39,13 @@ template< // The number of rows of Q, K or V loaded by this tile. int ROWS_, // The number of columns. - int COLS, - // The number of matrics. - int NUM_MATS = 3 + int COLS > struct Gmem_tile_qkv { using Cta_tile = Cta_tile_; + static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8; // The size of each LDG. static constexpr int BYTES_PER_LDG = 16; // The size of a row in bytes. @@ -62,11 +61,12 @@ struct Gmem_tile_qkv { static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG); // Ctor. - template< typename Params, typename BInfo > - inline __device__ Gmem_tile_qkv(const Params ¶ms, const int qkv_offset, const BInfo &binfo, const int tidx) - : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes) + template< typename BInfo > + inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) , actual_seqlen(binfo.actual_seqlen) - , qkv_ptr_(reinterpret_cast(params.qkv_ptr)) + , ptr(reinterpret_cast(ptr_)) , tidx_(tidx) { // Compute the position in the sequence (within the CTA for the moment). @@ -80,13 +80,13 @@ struct Gmem_tile_qkv { // The row offset in the batched GEMM. For each seq element, we store QKV in that order. // int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes; - uint32_t row_offset = (uint32_t)row * params.qkv_stride_in_bytes; + uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes); // Add the block index. // row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; - row_offset += (uint32_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW; + row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); // Assemble the final pointer. - qkv_ptr_ += row_offset + col * BYTES_PER_LDG; + ptr += row_offset + col * BYTES_PER_LDG; } // Store data to shared memory. @@ -101,8 +101,8 @@ struct Gmem_tile_qkv { uint32_t preds[LDGS]; #pragma unroll for( int ii = 0; ii < LDGS; ++ii ) { - // ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; - ptrs[ii] = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + // ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)); fetch_[ii] = make_uint4(0, 0, 0, 0); } @@ -120,32 +120,25 @@ struct Gmem_tile_qkv { int row_ = tidx_ / THREADS_PER_ROW; #pragma unroll for( int ii = 0; ii < LDGS; ++ii ) { - // char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; - char *ptr = qkv_ptr_ + (uint32_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_; + // char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes; + char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes; if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) { - fmha::stg(ptr, data[ii]); + fmha::stg(ptr_, data[ii]); } } } - // Move the pointer to the next location. - inline __device__ void move() { - // qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_; - qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_; - actual_seqlen -= ROWS; - } - - inline __device__ void move(int steps) { - // qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps; - qkv_ptr_ += (uint32_t)ROWS * params_qkv_stride_in_bytes_ * steps; + inline __device__ void move(const int steps = 1) { + // ptr += (int64_t)ROWS * row_stride_in_bytes * steps; + ptr += (uint32_t)ROWS * row_stride_in_bytes * steps; actual_seqlen -= ROWS * steps; } // The stride between rows for the QKV matrice. - // int64_t params_qkv_stride_in_bytes_; - uint32_t params_qkv_stride_in_bytes_; + // int64_t row_stride_in_bytes; + const uint32_t row_stride_in_bytes; // The pointer. - char *qkv_ptr_; + char *ptr; // The fetch registers. uint4 fetch_[LDGS]; // Keep track of the row the thread is processing as we move the tile. @@ -196,10 +189,10 @@ struct Gmem_tile_o { // Ctor. template - // inline __device__ Gmem_tile_o(void *ptr, const size_t stride_in_elts, const BInfo &binfo, const int tidx) - inline __device__ Gmem_tile_o(void *ptr, const uint32_t stride_in_elts, const BInfo &binfo, const int tidx) - : stride_in_bytes_(stride_in_elts * BYTES_PER_ELEMENT) - , actual_seqlen_(binfo.actual_seqlen) + // inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx) + inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts, + const uint32_t head_stride_in_elts, const BInfo &binfo, const int tidx) + : row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT) , actual_seqlen(binfo.actual_seqlen) , ptr_(reinterpret_cast(ptr)) , tidx_(tidx) { @@ -213,8 +206,9 @@ struct Gmem_tile_o { // row_ = row; // The row offset in the batched GEMM. - // int64_t row_offset = (int64_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW; - uint32_t row_offset = (uint32_t)row * stride_in_bytes_ + binfo.bidx * BYTES_PER_ROW; + // int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW; + uint32_t row_offset = (uint32_t)((binfo.sum_s + row) * row_stride_in_bytes); + row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT); // Assemble the final pointer. ptr_ += row_offset + col * BYTES_PER_STG; @@ -224,25 +218,19 @@ struct Gmem_tile_o { } } - template - inline __device__ Gmem_tile_o(const Params ¶ms, const BInfo &binfo, const int tidx) - : Gmem_tile_o(params.o_ptr, params.o_stride_in_elts, binfo, tidx) {} - // Store data to global memory. inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) { int row_ = tidx_ / THREADS_PER_ROW; #pragma unroll for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) { int jj = mi * STGS_PER_LOOP + ii; - // if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) { - // break; if( row_ + jj * ROWS_PER_STG >= this->actual_seqlen ) { break; } if (BYTES_PER_ELEMENT == 4) { if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, src[ii]); + fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]); } } else if (BYTES_PER_ELEMENT == 2) { float x = reinterpret_cast(src[ii].x); @@ -251,7 +239,7 @@ struct Gmem_tile_o { float w = reinterpret_cast(src[ii].w); uint2 out = float4_to_half4(x, y, z, w); if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_, out); + fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out); } } } @@ -269,37 +257,26 @@ struct Gmem_tile_o { } if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) { - fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->stride_in_bytes_); + fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes); } } } - // Move the pointer to the next location. - inline __device__ void move() { - // row_ += ROWS; - // ptr_ += (int64_t)ROWS * stride_in_bytes_; - ptr_ += (uint32_t)ROWS * stride_in_bytes_; - actual_seqlen -= ROWS; - } - - inline __device__ void move(const int steps) { + inline __device__ void move(const int steps = 1) { // row_ += ROWS * steps; - // ptr_ += (int64_t)ROWS * stride_in_bytes_ * steps; - ptr_ += (uint32_t)ROWS * stride_in_bytes_ * steps; + // ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps; + ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps; actual_seqlen -= ROWS * steps; } // The stride between rows for the QKV matrice. - // int64_t stride_in_bytes_; - uint32_t stride_in_bytes_; + // int64_t row_stride_in_bytes; + const uint32_t row_stride_in_bytes; // The pointer. char *ptr_; // Is the thread active for the last STG? int is_active_for_last_stg_; - // Keep track of the row to disable loads. - // int row_; // The length of the sequence loaded by that memory tile. - const int actual_seqlen_; int actual_seqlen; const int tidx_; }; @@ -363,10 +340,7 @@ struct Gmem_tile_mma_sd { } // Move to the next tile. - inline __device__ void move() { - ptr_ += LOOP_STRIDE_BYTES; - } - inline __device__ void move(const int steps) { + inline __device__ void move(const int steps = 1) { ptr_ += LOOP_STRIDE_BYTES * steps; } @@ -459,69 +433,6 @@ struct Gmem_tile_mma_s : public Base { //////////////////////////////////////////////////////////////////////////////////////////////////// -template< - // The dimensions of the tile computed by the CTA. - typename Cta_tile, - // The base class. - typename Base = fmha::Gmem_tile_qkv -> -struct Gmem_tile_dout : public Base { - - // Ctor. - template - inline __device__ Gmem_tile_dout(void *ptr, const Params ¶ms, const BInfo &binfo, int tidx) - : Base(params, 0, binfo, tidx) { - - // this->qkv_ptr_ = reinterpret_cast(params.do_ptr); - this->qkv_ptr_ = static_cast(ptr); - this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / Base::THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % Base::THREADS_PER_ROW; - - // The row offset in the batched GEMM. For each seq element, we store O in that order. - // int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; - // int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; - uint32_t row_offset = (uint32_t)row * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW; - - // Assemble the final pointer. - this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< typename Cta_tile, typename Base = fmha::Gmem_tile_o > -struct Gmem_tile_dq : public Base { - - // Ctor. - template - inline __device__ Gmem_tile_dq(const Params ¶ms, const int qkv_offset, const BInfo &binfo, int tidx) - : Base(params.dqkv_ptr, params.qkv_stride_in_elts, binfo, tidx) { - this->ptr_ = reinterpret_cast(params.dqkv_ptr); - - // Compute the position in the sequence (within the CTA for the moment). - int row = tidx / Base::THREADS_PER_ROW; - // Compute the position of the thread in the row. - int col = tidx % Base::THREADS_PER_ROW; - - // The row offset in the batched GEMM. For each seq element, we store O in that order. - // int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes + - // ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; - // int64_t row_offset = (int64_t)row * this->stride_in_bytes_ + - // ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; - uint32_t row_offset = (uint32_t)row * this->stride_in_bytes_ + - ((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW; - - // Assemble the final pointer. - this->ptr_ += row_offset + col * Base::BYTES_PER_STG; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template< // The dimensions of the tile computed by the CTA. typename Cta_tile diff --git a/csrc/flash_attn/src/fmha/kernel_traits.h b/csrc/flash_attn/src/fmha/kernel_traits.h index 52fd744..401d080 100644 --- a/csrc/flash_attn/src/fmha/kernel_traits.h +++ b/csrc/flash_attn/src/fmha/kernel_traits.h @@ -72,9 +72,7 @@ struct FMHA_kernel_traits { // The shared memory tile to transpose S. using Smem_tile_st = fmha::Smem_tile_mma_transposed; - using Gmem_tile_do = fmha::Gmem_tile_dout; - - using Gmem_tile_dot = fmha::Gmem_tile_dout >; + using Gmem_tile_do = fmha::Gmem_tile_qkv; // The global memory tile to store the softmax sum. using Gmem_softmax_sum = fmha::Gmem_summary_stats; diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index f65b696..a6dc8ff 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -77,8 +77,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, using Gmem_tile_o = Gmem_tile_do; // The global memory tile to store dQ. - // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq; - using Gmem_tile_dq = fmha::Gmem_tile_dq; + using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o; using Gmem_tile_dq_tmp = fmha::Gmem_tile_o; // The shared memory tile to swizzle dQ. using Smem_tile_dq = typename Kernel_traits::Smem_tile_o; @@ -139,19 +138,19 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for dQ. - Gmem_tile_dq gmem_dq(params, 0, binfo, tidx); - Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx); + Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); + Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params, 2, binfo, tidx); + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); // The base pointer of smem_v; char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V]; @@ -161,7 +160,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx); // Allocate the global memory tile loader for dO. - Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx); + Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the shared memory tile loader for dO. Smem_tile_do smem_do(&smem_[0], tidx); Smem_tile_dot smem_dot(&smem_[0], tidx); @@ -173,7 +172,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx); // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx); + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx); @@ -703,11 +702,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, __syncthreads(); uint4 dv_out[Smem_tile_dv::NUM_LDS]; smem_dv.load(dv_out); - Qkv_params dv_params; - dv_params.qkv_ptr = params.dqkv_ptr; - dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dv_params.h = params.h; - Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx); + Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); if (!Is_first) { gmem_dv.move(loop_step_idx); } @@ -718,11 +713,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) { // dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f); // } - Qkv_params dk_params; - dk_params.qkv_ptr = params.dqkv_ptr; - dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dk_params.h = params.h; - Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx); + Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); if (!Is_first) { gmem_dk.move(loop_step_idx); } diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h index 33674f1..5549d4b 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h @@ -97,10 +97,10 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c Gemm1 gemm_q_k(smem_, tidx); // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx); + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); + Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -122,12 +122,12 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params, 2, binfo, tidx); + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); // The base pointer of smem_v; char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V]; - + // Allocate the shared memory tile loader for V. We use the same as K so be careful!!! Smem_tile_v smem_v(smem_v_, tidx); @@ -193,7 +193,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c __syncthreads(); } - // Load the fragments for K. + // Load the fragments for K. gemm_q_k.load_k(); // Create the object to do the softmax. diff --git a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h index 8e32bff..f6f325b 100644 --- a/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h @@ -80,8 +80,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng using Gmem_tile_o = Gmem_tile_do; // The global memory tile to store dQ. - // using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_dq; - using Gmem_tile_dq = fmha::Gmem_tile_dq; + using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o; using Gmem_tile_dq_tmp = fmha::Gmem_tile_o; // The shared memory tile to swizzle dQ. using Smem_tile_dq = typename Kernel_traits::Smem_tile_o; @@ -132,19 +131,19 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx); // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for dQ. - Gmem_tile_dq gmem_dq(params, 0, binfo, tidx); - Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx); + Gmem_tile_dq gmem_dq(params.dqkv_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); + Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params, 2, binfo, tidx); + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); // The base pointer of smem_v; char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V]; @@ -154,7 +153,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx); // Allocate the global memory tile loader for dO. - Gmem_tile_do gmem_do(params.do_ptr, params, binfo, tidx); + Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the shared memory tile loader for dO. Smem_tile_do smem_do(&smem_[0], tidx); Smem_tile_dot smem_dot(&smem_[0], tidx); @@ -166,7 +165,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx); // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params.o_ptr, params, binfo, tidx); + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the shared memory tile loader for O. We use the same as K so be careful!!! Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx); @@ -654,11 +653,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng __syncthreads(); uint4 dv_out[Smem_tile_dv::NUM_LDS]; smem_dv.load(dv_out); - Qkv_params dv_params; - dv_params.qkv_ptr = params.dqkv_ptr; - dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dv_params.h = params.h; - Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx); + Gmem_tile_dv gmem_dv(params.dqkv_ptr + 2 * params.h * params.d * 2, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); if (!Is_first) { gmem_dv.move(loop_step_idx); } @@ -669,11 +664,7 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng // for (int ii = 0; ii < Smem_tile_dk::NUM_LDS; ++ii) { // dk_out[ii] = fmha::fmul4(dk_out[ii], params.scale_bmm1f); // } - Qkv_params dk_params; - dk_params.qkv_ptr = params.dqkv_ptr; - dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes; - dk_params.h = params.h; - Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx); + Gmem_tile_dk gmem_dk(params.dqkv_ptr + params.h * params.d * 2, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); if (!Is_first) { gmem_dk.move(loop_step_idx); } diff --git a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h index 9aa9e10..6502878 100644 --- a/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_fprop_kernel_1xN.h @@ -247,10 +247,10 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i Gemm1 gemm_q_k(smem_, tidx); // Allocate the global memory tile loader for Q. - Gmem_tile_q gmem_q(params, 0, binfo, tidx); + Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for O. - Gmem_tile_o gmem_o(params, binfo, tidx); - Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_stride_in_elts, binfo, tidx); + Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); + Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for S. Gmem_tile_s gmem_s(params, binfo, tidx); Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx); @@ -273,9 +273,9 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i fmha::Mask mask(binfo, tidx, loop_step_idx); // Allocate the global memory tile loader for K. - Gmem_tile_k gmem_k(params, 1, binfo, tidx); + Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts, binfo, tidx); // Allocate the global memory tile loader for V. - Gmem_tile_v gmem_v(params, 2, binfo, tidx); + Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts, binfo, tidx); // The base pointer of smem_v; char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];