From 9ee0ff1d9b6a99630e2a6868b9291dfa32d35abd Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Thu, 20 Jul 2023 17:39:00 -0700 Subject: [PATCH] Fix using dO stride for O, which can cause memory error in bwd --- csrc/flash_attn/src/flash_bwd_kernel.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 74f25ba..322551b 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -141,7 +141,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); + make_stride(params.o_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), Shape, Int>{}, Stride, _1>{}); Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), @@ -474,7 +474,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); + make_stride(params.o_row_stride, _1{})); Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, make_stride(params.dq_row_stride, _1{})); @@ -1098,7 +1098,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.do_row_stride + bidh * params.o_head_stride; + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; // We'll advance gdKaccum and gdVaccum before the first write. const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded + n_block_max * kBlockN) * params.d_rounded; @@ -1119,7 +1119,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); + make_stride(params.o_row_stride, _1{})); Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), Shape, Int>{}, Stride, _1>{});