509 lines
20 KiB
Python
509 lines
20 KiB
Python
![]() |
#################################################################################################
|
||
|
#
|
||
|
# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
# SPDX-License-Identifier: BSD-3-Clause
|
||
|
#
|
||
|
# Redistribution and use in source and binary forms, with or without
|
||
|
# modification, are permitted provided that the following conditions are met:
|
||
|
#
|
||
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
||
|
# list of conditions and the following disclaimer.
|
||
|
#
|
||
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
||
|
# this list of conditions and the following disclaimer in the documentation
|
||
|
# and/or other materials provided with the distribution.
|
||
|
#
|
||
|
# 3. Neither the name of the copyright holder nor the names of its
|
||
|
# contributors may be used to endorse or promote products derived from
|
||
|
# this software without specific prior written permission.
|
||
|
#
|
||
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||
|
#
|
||
|
#################################################################################################
|
||
|
|
||
|
"""
|
||
|
Utility functions for the Conv2d tests
|
||
|
"""
|
||
|
import torch
|
||
|
import cutlass
|
||
|
import unittest
|
||
|
import cutlass_bindings
|
||
|
from cutlass.utils.datatypes import binding_type, binding_opclass
|
||
|
from cutlass.backend.test.conv2d_testbed import Conv2dLauncher, getTensorRef, getTensorView
|
||
|
from cutlass.backend.utils.device import device_cc
|
||
|
from cutlass.backend.test.utils import get_name_conv2d
|
||
|
import numpy as np
|
||
|
|
||
|
def conv2d_few_channel_problemsizes(channels):
    """Build the Conv2d problem sizes used for few-channel tests.

    Every problem has batch 1, ``channels`` input channels (shared by the
    input and the filter), padding ``Tensor4DCoord(1, 1, 1, 1)``, dilation
    ``(1, 1)`` and cross-correlation mode. The problems differ in spatial
    extent, number of filters, filter size and stride.

    :param channels: channel count of both the input and the filter
    :return: list of ``cutlass_bindings.conv.Conv2dProblemSize``
    """
    # (input H, input W, filters K, filter R, filter S, stride in both dims)
    shape_table = (
        (8, 8, 16, 3, 3, 2),
        (16, 16, 16, 3, 3, 2),
        (16, 16, 16, 7, 7, 1),
        (224, 224, 32, 7, 7, 1),
        (224, 224, 64, 7, 7, 2),
        (224, 224, 64, 5, 5, 1),
        (224, 224, 64, 5, 5, 2),
    )
    return [
        cutlass_bindings.conv.Conv2dProblemSize(
            cutlass_bindings.Tensor4DCoord(1, height, width, channels),
            cutlass_bindings.Tensor4DCoord(filters, filt_r, filt_s, channels),
            cutlass_bindings.Tensor4DCoord(1, 1, 1, 1),
            cutlass_bindings.MatrixCoord(step, step),
            cutlass_bindings.MatrixCoord(1, 1),
            cutlass_bindings.conv.Mode.cross_correlation,
            1, 1,
        )
        for height, width, filters, filt_r, filt_s, step in shape_table
    ]
|
||
|
|
||
|
# Map CUTLASS element types to the torch dtypes used when the launcher
# creates CUDA tensors for the "torch" backend. bf16 is included because the
# launcher handles bf16 operands explicitly and torch provides bfloat16;
# without it, a bf16 plan on the torch backend fails with a KeyError.
torch_dtype = {
    cutlass.DataType.f16: torch.float16,
    cutlass.DataType.bf16: torch.bfloat16,
    cutlass.DataType.f32: torch.float32,
    cutlass.DataType.f64: torch.float64,
}
|
||
|
|
||
|
# Map CUTLASS element types to the NumPy dtypes used when the launcher
# creates host tensors for the "numpy" backend.
numpy_dtype = dict(
    [
        (cutlass.DataType.f16, np.float16),
        (cutlass.DataType.f32, np.float32),
        (cutlass.DataType.f64, np.float64),
    ]
)
|
||
|
|
||
|
|
||
|
def validate_problem_size(ps, conv_kind, split_k_slices):
    """Decide whether problem ``ps`` should be run for ``conv_kind``.

    The problem is rejected when its stored output extents (``ps.P``,
    ``ps.Q``) disagree with the extents implied by its input size, padding,
    dilation, filter size and stride. It is also rejected for strided dgrad
    combined with split-K.

    :param ps: problem-size object exposing H/W, pad_h/pad_w, dilation_h/w,
        R/S, stride_h/w and P/Q
    :param conv_kind: conv kind string (compared against ``"dgrad"``)
    :param split_k_slices: number of split-K slices the test will use
    :return: True if the problem is valid for this configuration
    """
    def _output_extent(extent, pad, dilation, filt, step):
        # Standard convolution output-size formula for one spatial dimension.
        return (extent + 2 * pad - dilation * (filt - 1) - 1) // step + 1

    expected_p = _output_extent(ps.H, ps.pad_h, ps.dilation_h, ps.R, ps.stride_h)
    expected_q = _output_extent(ps.W, ps.pad_w, ps.dilation_w, ps.S, ps.stride_w)
    if (expected_p, expected_q) != (ps.P, ps.Q):
        return False

    # Split-K (serial or parallel) is not supported for strided dgrad
    return not (
        conv_kind == "dgrad"
        and split_k_slices > 1
        and (ps.stride_h > 1 or ps.stride_w > 1)
    )
|
||
|
|
||
|
|
||
|
# Override the backend launcher
|
||
|
class Conv2dLauncherFrontend(Conv2dLauncher):
    """Run a front-end ``cutlass.Conv2d`` plan and check it against a reference.

    Overrides the backend ``Conv2dLauncher`` so that element types, conv kind
    and activation come from a ``cutlass.Conv2d`` plan.

    - ``backend="numpy"``: operands are NumPy arrays in NHWC order. They are
      checked against the CUTLASS host reference, optionally through cached
      result hashes.
    - ``backend="torch"``: operands are channels-last CUDA tensors with NCHW
      logical shapes. They are checked against PyTorch convolution ops.
    """

    def __init__(self, plan: cutlass.Conv2d, seed: int = 80, backend="numpy"):
        """Configure the launcher from ``plan``.

        :param plan: front-end Conv2d plan; supplies the conv kind, element
            types and epilogue activation
        :param seed: RNG seed applied before the operands are initialized
        :param backend: ``"numpy"`` or ``"torch"`` (see the class docstring)
        """
        self.operation = plan
        self.conv_kind = plan.conv_kind
        self.seed = seed
        self.backend = backend

        self.dtype_A = plan._element_a
        self.dtype_B = plan._element_b
        self.dtype_C = plan._element_c
        self.dtype_acc = plan._element_accumulator

        self.layout_A = cutlass_bindings.TensorNHWC
        self.layout_B = cutlass_bindings.TensorNHWC
        self.layout_C = cutlass_bindings.TensorNHWC
        self.layout_D = cutlass_bindings.TensorNHWC

        self.element_compute = cutlass_bindings.float32
        self.enable_cached_results = True

        # Bound on the magnitude of the random operands. 16-bit operand types
        # get a smaller range, presumably to keep results exactly representable.
        if self.dtype_A in [cutlass.DataType.f16, cutlass.DataType.bf16]:
            if self.dtype_acc in [cutlass.DataType.f16, cutlass.DataType.bf16]:
                self.randomization_max = 2
            else:
                self.randomization_max = 3
        else:
            self.randomization_max = 7

        self.activation = plan.activation

        self.host_conv2d = cutlass_bindings.test.conv.host.conv2d

    def set_seed(self):
        """Seed the RNG of the active backend with ``self.seed``."""
        if self.backend == "numpy":
            np.random.seed(self.seed)
        else:
            torch.manual_seed(self.seed)

    def uniform_init(self, size, dtype):
        """Create a random operand of shape ``size`` and CUTLASS type ``dtype``.

        NumPy backend: delegates to the base launcher with the NumPy dtype.
        Torch backend: creates a CUDA tensor whose values are the ceiling of a
        uniform sample in ``[-max - 0.5, max - 0.5)``, i.e. integers in
        ``[-randomization_max, randomization_max]``, in channels-last format.
        """
        if self.backend == "numpy":
            return super().uniform_init(size, numpy_dtype[dtype])
        else:
            tensor = torch.ceil(
                torch.empty(size=size, dtype=torch_dtype[dtype], device="cuda").uniform_(-self.randomization_max - 0.5, self.randomization_max - 0.5)
            ).to(memory_format=torch.channels_last)
            return tensor

    def zeros_like(self, tensor):
        """Return a zero tensor shaped like ``tensor`` for the active backend."""
        if self.backend == "numpy":
            return np.zeros_like(tensor)
        else:
            return torch.zeros_like(tensor).to(memory_format=torch.channels_last)

    def reference(self, ps, A, B, C, alpha, beta, activation):
        """Compute the expected output D for problem ``ps``.

        NumPy backend: delegates to :meth:`host_reference`, which may return a
        cached tensor hash instead of a tensor.

        Torch backend: evaluates ``alpha * conv + beta * C`` with PyTorch ops,
        where ``conv`` depends on the conv kind. The problem's stride, padding
        and dilation are honored for fprop, dgrad and wgrad. The epilogue
        activation (relu / leaky_relu) is then applied.

        :raises Exception: if the conv kind is not fprop, dgrad or wgrad
        """
        if self.backend == "numpy":
            numpy_result = self.host_reference(ps, A, B, C, alpha, beta, activation)
            return numpy_result

        stride = (ps.stride_h, ps.stride_w)
        padding = (ps.pad_h, ps.pad_w)
        dilation = (ps.dilation_h, ps.dilation_w)

        if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
            conv = torch.ops.aten.conv2d(
                A,
                B,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
            # A is the output gradient (N, K, P, Q) and B the filter
            # (K, C, R, S), matching the shapes allocated in run().
            conv = torch.nn.grad.conv2d_input(
                (ps.N, ps.C, ps.H, ps.W),
                B,
                A,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
            # A is the output gradient (N, K, P, Q) and B the activation
            # (N, C, H, W), matching the shapes allocated in run().
            conv = torch.nn.grad.conv2d_weight(
                B,
                (ps.K, ps.C, ps.R, ps.S),
                A,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )
        else:
            raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.")

        torch_result = alpha * conv + beta * C

        if activation == cutlass.backend.epilogue.relu:
            torch_result = torch.nn.functional.relu(torch_result)
        elif activation == cutlass.backend.epilogue.leaky_relu:
            # Same 0.5 negative slope that host_reference uses.
            torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5)

        return torch_result

    def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta, activation):
        """Compute, or look up, the host reference result for the NumPy backend.

        When ``enable_cached_results`` is set, a key is built from:
        the conv kind, the problem size, the scalars, the A/B/C tensor views
        and the activation tag. The key is looked up in the listing file
        ``cached_results_SM<arch>_<seed>.txt``.

        - On a hit, the cached tensor hash is returned.
        - On a miss, the reference is computed on the host, its hash is
          appended to the listing file, and that hash is returned.

        With caching disabled, the computed tensor itself is returned.
        """
        # Convert the epilogue scalars to the compute element type.
        if self.element_compute == cutlass_bindings.float16:
            alpha = cutlass_bindings.float16(alpha)
            beta = cutlass_bindings.float16(beta)
        elif self.element_compute == cutlass_bindings.int32:
            alpha = int(alpha)
            beta = int(beta)

        # If cached result is loaded
        cached_result_loaded = False

        if self.enable_cached_results:
            # Get problem key
            cached_test_key = cutlass_bindings.test.conv.host.CreateCachedConv2dTestKey(
                self.conv_kind,
                problem_size,
                alpha,
                beta,
                getTensorView(
                    tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
                ),
                getTensorView(
                    tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
                ),
                getTensorView(
                    tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
                ),
            )

            # Make the key distinguish results that differ only by activation.
            cached_test_key.problem = cached_test_key.problem + f"_{activation.tag.split('::')[-1]}"

            cached_test_result = cutlass_bindings.test.conv.host.CachedTestResult()

            conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % (
                self.operation.arch,
                self.seed,
            )

            cached_results = cutlass_bindings.test.conv.host.CachedTestResultListing(
                conv2d_result_cache_name
            )
            # The lookup result is indexed as (found, cached result).
            cached = cached_results.find(cached_test_key)
            cached_result_loaded = cached[0]
            if cached_result_loaded:
                cached_test_result = cached[1]

        if not cached_result_loaded:
            # Compute the conv2d on host
            tensor_D_ref = np.ones_like(tensor_C)
            tensor_ref_A = getTensorRef(
                tensor_A, self.layout_A, self.conv_kind, problem_size, "a"
            )
            tensor_ref_B = getTensorRef(
                tensor_B, self.layout_B, self.conv_kind, problem_size, "b"
            )
            tensor_ref_C = getTensorRef(
                tensor_C, self.layout_C, self.conv_kind, problem_size, "c"
            )
            tensor_ref_D_ref = getTensorRef(
                tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
            )

            self.host_conv2d(
                self.conv_kind,
                problem_size,
                tensor_ref_A,
                tensor_ref_B,
                tensor_ref_C,
                tensor_ref_D_ref,
                alpha,
                beta,
            )

            # Apply the epilogue activation (leaky_relu uses a 0.5 negative slope).
            if activation == cutlass.backend.epilogue.leaky_relu:
                tensor_D_ref = activation.numpy(tensor_D_ref, 0.5)
            else:
                tensor_D_ref = activation.numpy(tensor_D_ref)

            tensor_view_D_ref = getTensorView(
                tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d"
            )

            if self.enable_cached_results:
                cached_test_result.D = cutlass_bindings.test.conv.host.TensorHash(
                    tensor_view_D_ref
                )
                # Re-open the listing before appending the new entry and
                # writing the file back.
                cached_results = (
                    cutlass_bindings.test.conv.host.CachedTestResultListing(
                        conv2d_result_cache_name
                    )
                )
                cached_results.append(cached_test_key, cached_test_result)
                cached_results.write(conv2d_result_cache_name)
            else:
                return tensor_D_ref

        return cached_test_result.D

    def equal(self, tensor_D, tensor_D_ref, problem_size):
        """Compare the computed D with the reference for the active backend.

        The torch path synchronizes the device first, then requires exact
        element-wise equality (``torch.equal``).
        """
        if self.backend == "numpy":
            return super().equal(tensor_D, tensor_D_ref, problem_size)
        else:
            torch.cuda.synchronize()
            return torch.equal(tensor_D, tensor_D_ref)

    def run(self, ps, split_k_mode=cutlass_bindings.conv.SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0):
        """Run the plan on random operands for ``ps`` and compare with the reference.

        :param ps: ``Conv2dProblemSize`` describing the convolution
        :param split_k_mode: split-K mode forwarded to the plan's ``run``
        :param split_k_slices: number of split-K slices forwarded to the plan's ``run``
        :param alpha: epilogue scale applied to the convolution result
        :param beta: epilogue scale applied to C
        :return: the result of :meth:`equal` on the computed and reference D
        :raises Exception: if the conv kind is not fprop, dgrad or wgrad
        """
        #
        # Initialize input and output tensors
        #
        # Torch operands use NCHW logical shapes (stored channels-last by
        # uniform_init); NumPy operands are created directly in NHWC order.
        if self.conv_kind == cutlass_bindings.conv.Operator.fprop:
            if self.backend == "torch":
                tensor_A_size = (ps.N, ps.C, ps.H, ps.W)
                tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
                tensor_C_size = (ps.N, ps.K, ps.P, ps.Q)
            else:
                tensor_A_size = (ps.N, ps.H, ps.W, ps.C)
                tensor_B_size = (ps.K, ps.R, ps.S, ps.C)
                tensor_C_size = (ps.N, ps.P, ps.Q, ps.K)
        elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad:
            if self.backend == "torch":
                tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
                tensor_B_size = (ps.K, ps.C, ps.R, ps.S)
                tensor_C_size = (ps.N, ps.C, ps.H, ps.W)
            else:
                tensor_A_size = (ps.N, ps.P, ps.Q, ps.K)
                tensor_B_size = (ps.K, ps.R, ps.S, ps.C)
                tensor_C_size = (ps.N, ps.H, ps.W, ps.C)
        elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad:
            if self.backend == "torch":
                tensor_A_size = (ps.N, ps.K, ps.P, ps.Q)
                tensor_B_size = (ps.N, ps.C, ps.H, ps.W)
                tensor_C_size = (ps.K, ps.C, ps.R, ps.S)
            else:
                tensor_A_size = (ps.N, ps.P, ps.Q, ps.K)
                tensor_B_size = (ps.N, ps.H, ps.W, ps.C)
                tensor_C_size = (ps.K, ps.R, ps.S, ps.C)
        else:
            raise Exception(f"Conv kind {self.conv_kind} is not supported")

        self.set_seed()

        tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A)
        tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B)
        tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C)
        tensor_D = self.zeros_like(tensor_C)

        self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D,
                           stride=(ps.stride_h, ps.stride_w),
                           padding=(ps.pad_h, ps.pad_w),
                           dilation=(ps.dilation_h, ps.dilation_w),
                           alpha=alpha, beta=beta,
                           split_k=(split_k_mode, split_k_slices))

        tensor_D_ref = self.reference(
            ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation
        )

        return self.equal(tensor_D, tensor_D_ref, ps)
|
||
|
|
||
|
|
||
|
def add_test(
    cls,
    cc,
    conv_kind,
    problem_sizes,
    element,
    element_accumulator,
    element_output,
    opclass,
    threadblock_shape,
    warp_count,
    instruction_shape,
    stages,
    iterator_algorithm=None,
    swizzle=None,
    split_k_mode="serial",
    split_k_slices=1,
    activation="identity"
):
    """Create a test-running function with the given specification.

    The generated ``run(self)`` method does the following:

    1. Builds a ``cutlass.Conv2d`` plan from the element types, opclass,
       tile description, optional iterator algorithm and swizzle, and
       activation.
    2. For every problem in ``problem_sizes`` accepted by
       ``validate_problem_size``, asserts that a NumPy-backed
       ``Conv2dLauncherFrontend`` run (alpha=1.0, beta=0.5) matches the
       reference.

    The method is attached to ``cls`` under the name produced by
    ``get_name_conv2d`` and is also returned.
    """
    test_name = get_name_conv2d(
        cc, conv_kind, element, element_accumulator, element_output, opclass,
        threadblock_shape, warp_count, instruction_shape, stages,
        iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation)

    def run(self):
        # Create the plan with the requested element types.
        plan = cutlass.Conv2d(
            kind=conv_kind,
            element=element,
            element_accumulator=element_accumulator,
            element_C=element_output,
            element_D=element_output,
        )
        plan.opclass = opclass
        plan.tile_description = dict(
            threadblock_shape=threadblock_shape,
            warp_count=warp_count,
            stages=stages,
            instruction_shape=instruction_shape,
        )

        # Optional kernel knobs are only set when explicitly requested.
        if iterator_algorithm is not None:
            plan.iterator_algorithm = iterator_algorithm
        if swizzle is not None:
            plan.swizzling_stride = swizzle

        # "identity" keeps the plan's default epilogue; leaky_relu needs its slope.
        if activation == "leaky_relu":
            plan.activation = (cutlass.epilogue.leaky_relu, 0.5)
        elif activation != "identity":
            plan.activation = getattr(cutlass.epilogue, activation)

        launcher = Conv2dLauncherFrontend(plan, 80, backend="numpy")

        runnable = (
            problem for problem in problem_sizes
            if validate_problem_size(problem, conv_kind, split_k_slices)
        )
        for problem in runnable:
            self.assertTrue(
                launcher.run(problem, split_k_mode, split_k_slices, 1.0, 0.5)
            )

    setattr(cls, test_name, run)
    return run
|
||
|
|
||
|
|
||
|
def get_conv_problems():
    """Return the Conv2d problem sizes used by the tests.

    The list starts with the default CUTLASS testbed problem sizes, built
    with 64 passed as the minimum channel size. Three extra problems follow,
    with 12, 14 and 98 channels, which exercise alignment 4 and alignment 2
    cases. Each extra problem uses stride (3, 3), dilation (1, 1) and
    cross-correlation mode.
    """
    # 64: minimum channel size
    problems = list(cutlass_bindings.test.conv.TestbedConv2dProblemSizes(64).conv2d_default_sizes)

    # Insert alignment 4 & 2 tests: (input (N, H, W, C), filter (K, R, S, C), padding)
    alignment_specs = [
        ((1, 4, 4, 12), (8, 3, 3, 12), (0, 0, 0, 0)),
        ((1, 4, 4, 14), (8, 3, 3, 14), (0, 0, 0, 0)),
        ((1, 23, 56, 98), (128, 3, 3, 98), (4, 0, 5, 0)),
    ]
    for input_dims, filter_dims, pad_dims in alignment_specs:
        problems.append(
            cutlass_bindings.conv.Conv2dProblemSize(
                cutlass_bindings.Tensor4DCoord(*input_dims),
                cutlass_bindings.Tensor4DCoord(*filter_dims),
                cutlass_bindings.Tensor4DCoord(*pad_dims),
                cutlass_bindings.MatrixCoord(3, 3),
                cutlass_bindings.MatrixCoord(1, 1),
                cutlass_bindings.conv.Mode.cross_correlation,
                1, 1,
            )
        )

    return problems
|