Fix using dO stride for O, which can cause memory error in bwd
This commit is contained in:
parent
2dd87d0609
commit
9ee0ff1d9b
@ -141,7 +141,7 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) {
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
|
||||
@ -474,7 +474,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.dq_row_stride, _1{}));
|
||||
@ -1098,7 +1098,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
|
||||
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.do_row_stride + bidh * params.o_head_stride;
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
// We'll advance gdKaccum and gdVaccum before the first write.
|
||||
const index_t row_offset_dkv_accum = ((bidb * params.h_k + (bidh / params.h_h_k_ratio)) * params.seqlen_k_rounded
|
||||
+ n_block_max * kBlockN) * params.d_rounded;
|
||||
@ -1119,7 +1119,7 @@ inline __device__ void compute_dq_dk_dv_1rowblock(const Params ¶ms, const in
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
Loading…
Reference in New Issue
Block a user