From e7a61c761a4bfb387b61c03cdbcd19ab300726b7 Mon Sep 17 00:00:00 2001
From: Haicheng Wu <57973641+hwu36@users.noreply.github.com>
Date: Tue, 12 Jul 2022 16:37:08 -0400
Subject: [PATCH] fix race condition when h < stride_h or w < stride_w (#562)

Co-authored-by: Haicheng Wu
---
 .../implicit_gemm_convolution_strided_dgrad.h | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h
index 31958a42..65191f5a 100644
--- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h
+++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h
@@ -307,6 +307,26 @@ struct ImplicitGemmConvolutionStridedDgrad {
     int start_r, start_s;
     params.stride_w_divmod(start_r, start_s, filter_tile_m);
 
+    int filter_r = start_r;
+    int filter_s = start_s;
+
+    if (params.problem_size.mode == Mode::kConvolution) {
+      filter_r = (params.problem_size.R - 1 - filter_r);
+      filter_s = (params.problem_size.S - 1 - filter_s);
+    }
+
+    // Starting h, w positions for filter position in gemm_k=0
+    int start_h, start_w;
+    strided_dgrad_starting_coords(
+      params.problem_size,
+      params.stride_h_divmod, params.stride_w_divmod,
+      filter_r, filter_s,
+      start_h, start_w);
+
+    if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) {
+      return;
+    }
+
     typename Mma::FragmentC accumulators;
 
     accumulators.clear();