diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 74f25ba..322551b 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -141,7 +141,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); + make_stride(params.o_row_stride, _1{})); Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), Shape, Int>{}, Stride, _1>{}); Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), @@ -474,7 +474,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); + make_stride(params.o_row_stride, _1{})); Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), Shape, Int>{}, make_stride(params.dq_row_stride, _1{})); @@ -1098,7 +1098,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.do_row_stride + bidh * params.o_head_stride; + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; // We'll advance gdKaccum and gdVaccum before the first write. const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded + n_block_max * kBlockN) * params.d_rounded; @@ -1119,7 +1119,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in make_stride(params.do_row_stride, _1{})); Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), Shape, Int>{}, - make_stride(params.do_row_stride, _1{})); + make_stride(params.o_row_stride, _1{})); Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), Shape, Int>{}, Stride, _1>{});