diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 31958a42..65191f5a 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -307,6 +307,26 @@ struct ImplicitGemmConvolutionStridedDgrad { int start_r, start_s; params.stride_w_divmod(start_r, start_s, filter_tile_m); + int filter_r = start_r; + int filter_s = start_s; + + if (params.problem_size.mode == Mode::kConvolution) { + filter_r = (params.problem_size.R - 1 - filter_r); + filter_s = (params.problem_size.S - 1 - filter_s); + } + + // Starting h, w positions for filter position in gemm_k=0 + int start_h, start_w; + strided_dgrad_starting_coords( + params.problem_size, + params.stride_h_divmod, params.stride_w_divmod, + filter_r, filter_s, + start_h, start_w); + + if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) { + return; + } + typename Mma::FragmentC accumulators; accumulators.clear();