Add some can implement rules of hopper convolution. (#1835)
This commit is contained in:
parent
44dae8b90e
commit
e2b0789927
@ -375,6 +375,61 @@ public:
|
||||
return false;
|
||||
}
|
||||
|
||||
if (is_im2col_A || is_im2col_B) {
|
||||
// Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
|
||||
constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
|
||||
auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
|
||||
for (int i = 0; i < problem_shape.RankS; ++i) {
|
||||
implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
|
||||
}
|
||||
auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
|
||||
for (int i = 0; i < problem_shape.RankS; ++i) {
|
||||
implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
|
||||
}
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
|
||||
if constexpr (ConvOp == conv::Operator::kWgrad) {
|
||||
|
||||
const auto & input_shape = problem_shape.shape_A;
|
||||
const auto & input_stride = problem_shape.stride_A;
|
||||
|
||||
implementable &= input_stride[ProblemShape::RankT - 1] == 1;
|
||||
int input_shape_size = 1;
|
||||
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
|
||||
input_shape_size *= input_shape[i + 1];
|
||||
implementable &= input_stride[i] == input_shape_size;
|
||||
}
|
||||
|
||||
const auto & output_shape = problem_shape.shape_C;
|
||||
const auto & output_stride = problem_shape.stride_C;
|
||||
|
||||
implementable &= output_stride[ProblemShape::RankT - 1] == 1;
|
||||
int output_shape_size = 1;
|
||||
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
|
||||
output_shape_size *= output_shape[i + 1];
|
||||
implementable &= output_stride[i] == output_shape_size;
|
||||
}
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Conv kernels only support cross correlation mode currently.
|
||||
implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;
|
||||
|
||||
if (!implementable) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (problem_shape.groups > 1) {
|
||||
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n");
|
||||
return false;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user