Add some can_implement rules for Hopper convolution. (#1835)

This commit is contained in:
Junkai-Wu 2024-09-25 23:28:10 +08:00 committed by GitHub
parent 44dae8b90e
commit e2b0789927
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -375,6 +375,61 @@ public:
return false;
}
if (is_im2col_A || is_im2col_B) {
// Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1]
constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1);
auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape);
for (int i = 0; i < problem_shape.RankS; ++i) {
implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1);
}
auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape);
for (int i = 0; i < problem_shape.RankS; ++i) {
implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1);
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n");
return false;
}
}
// Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized)
if constexpr (ConvOp == conv::Operator::kWgrad) {
const auto & input_shape = problem_shape.shape_A;
const auto & input_stride = problem_shape.stride_A;
implementable &= input_stride[ProblemShape::RankT - 1] == 1;
int input_shape_size = 1;
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
input_shape_size *= input_shape[i + 1];
implementable &= input_stride[i] == input_shape_size;
}
const auto & output_shape = problem_shape.shape_C;
const auto & output_stride = problem_shape.stride_C;
implementable &= output_stride[ProblemShape::RankT - 1] == 1;
int output_shape_size = 1;
for (int i = ProblemShape::RankT - 2; i >= 0; --i) {
output_shape_size *= output_shape[i + 1];
implementable &= output_stride[i] == output_shape_size;
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n");
return false;
}
}
// Conv kernels only support cross correlation mode currently.
implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation;
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n");
return false;
}
if (problem_shape.groups > 1) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n");
return false;