fix race condition when h < stride_h or w < stride_w (#562)
Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
parent
fb379eaa5b
commit
e7a61c761a
@ -307,6 +307,26 @@ struct ImplicitGemmConvolutionStridedDgrad {
|
||||
int start_r, start_s;
|
||||
params.stride_w_divmod(start_r, start_s, filter_tile_m);
|
||||
|
||||
int filter_r = start_r;
|
||||
int filter_s = start_s;
|
||||
|
||||
if (params.problem_size.mode == Mode::kConvolution) {
|
||||
filter_r = (params.problem_size.R - 1 - filter_r);
|
||||
filter_s = (params.problem_size.S - 1 - filter_s);
|
||||
}
|
||||
|
||||
// Starting h, w positions for filter position in gemm_k=0
|
||||
int start_h, start_w;
|
||||
strided_dgrad_starting_coords(
|
||||
params.problem_size,
|
||||
params.stride_h_divmod, params.stride_w_divmod,
|
||||
filter_r, filter_s,
|
||||
start_h, start_w);
|
||||
|
||||
if (start_h >= params.problem_size.H || start_w >= params.problem_size.W) {
|
||||
return;
|
||||
}
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
Loading…
Reference in New Issue
Block a user