Make Python interface work for non-SM80 targets (#726)
* Make Python interface work for non-SM80 targets * Remove line in README
This commit is contained in:
parent
d6117ca362
commit
df81d847d7
@ -2,7 +2,6 @@
|
||||
This directory contains examples of using CUTLASS's Python interface. It consists of two types of examples:
|
||||
* _Basic examples_: minimal examples that illustrate how to set up GEMMs, convolutions, and grouped GEMM operations
|
||||
* [_Customizable examples_](customizable): examples that allow one to specify a variety of template parameters for the given kernel
|
||||
>>>>>>> Add simplified examples
|
||||
|
||||
## Setting up the Python interface
|
||||
Please follow the instructions [here](/tools/library/scripts/pycutlass/README.md#installation) to set up the Python API.
|
||||
|
@ -41,7 +41,7 @@ import sys
|
||||
import cutlass
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
import util
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
@ -62,7 +62,7 @@ except:
|
||||
sys.exit(0)
|
||||
|
||||
# Check that the device is of a sufficient compute capability
|
||||
cc = util.get_device_cc()
|
||||
cc = device_cc()
|
||||
assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70."
|
||||
|
||||
alignment = 1
|
||||
@ -82,8 +82,17 @@ C = TensorDescription(cutlass.float32, cutlass.TensorNHWC, alignment)
|
||||
element_acc = cutlass.float32
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
# Select instruction shape based on the Tensor Core instructions supported
|
||||
# by the device on which we are running
|
||||
if cc == 70:
|
||||
instruction_shape = [8, 8, 4]
|
||||
elif cc == 75:
|
||||
instruction_shape = [16, 8, 8]
|
||||
else:
|
||||
instruction_shape = [16, 8, 16]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
[16, 8, 8], # Shape of the Tensor Core instruction
|
||||
instruction_shape,
|
||||
A.element, B.element, element_acc,
|
||||
cutlass.OpClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
|
@ -34,6 +34,7 @@ import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.conv2d_operation import *
|
||||
from pycutlass.utils import reference_model
|
||||
from pycutlass.utils.device import device_cc
|
||||
import sys
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -146,6 +147,11 @@ try:
|
||||
except:
|
||||
sys.exit(0)
|
||||
|
||||
cc = device_cc()
|
||||
if args.compute_capability != cc:
|
||||
raise Exception(("Parameter --compute-capability of {} "
|
||||
"does not match that of the device of {}.").format(args.compute_capability, cc))
|
||||
|
||||
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
|
||||
np.random.seed(0)
|
||||
|
@ -34,6 +34,7 @@ import pycutlass
|
||||
from pycutlass import *
|
||||
import cutlass
|
||||
from bfloat16 import bfloat16
|
||||
from pycutlass.utils.device import device_cc
|
||||
import sys
|
||||
|
||||
import argparse
|
||||
@ -131,12 +132,16 @@ parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", ty
|
||||
parser.add_argument('--print_cuda', action="store_true",
|
||||
help="print the underlying CUDA kernel")
|
||||
|
||||
|
||||
try:
|
||||
args = parser.parse_args()
|
||||
except:
|
||||
sys.exit(0)
|
||||
|
||||
cc = device_cc()
|
||||
if args.compute_capability != cc:
|
||||
raise Exception(("Parameter --compute-capability of {} "
|
||||
"does not match that of the device of {}.").format(args.compute_capability, cc))
|
||||
|
||||
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
pycutlass.compiler.nvcc()
|
||||
|
||||
|
@ -32,6 +32,7 @@
|
||||
import numpy as np
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import csv
|
||||
import sys
|
||||
|
||||
@ -129,6 +130,11 @@ try:
|
||||
except:
|
||||
sys.exit(0)
|
||||
|
||||
cc = device_cc()
|
||||
if args.compute_capability != cc:
|
||||
raise Exception(("Parameter --compute-capability of {} "
|
||||
"does not match that of the device of {}.").format(args.compute_capability, cc))
|
||||
|
||||
pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32)
|
||||
|
||||
np.random.seed(0)
|
||||
|
@ -40,7 +40,7 @@ import sys
|
||||
import cutlass
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
import util
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'")
|
||||
@ -55,7 +55,7 @@ except:
|
||||
sys.exit(0)
|
||||
|
||||
# Check that the device is of a sufficient compute capability
|
||||
cc = util.get_device_cc()
|
||||
cc = device_cc()
|
||||
assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70."
|
||||
|
||||
alignment = 8
|
||||
@ -78,13 +78,23 @@ C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
|
||||
element_acc = cutlass.float32
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
# Select instruction shape based on the Tensor Core instructions supported
|
||||
# by the device on which we are running
|
||||
if cc == 70:
|
||||
instruction_shape = [8, 8, 4]
|
||||
elif cc == 75:
|
||||
instruction_shape = [16, 8, 8]
|
||||
else:
|
||||
instruction_shape = [16, 8, 16]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
[16, 8, 8], # Shape of the Tensor Core instruction
|
||||
instruction_shape,
|
||||
A.element, B.element, element_acc,
|
||||
cutlass.OpClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
)
|
||||
|
||||
|
||||
tile_description = TileDescription(
|
||||
[128, 128, 32], # Threadblock shape
|
||||
2, # Number of stages
|
||||
|
@ -40,7 +40,7 @@ import sys
|
||||
import cutlass
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
import util
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python")
|
||||
@ -52,7 +52,7 @@ except:
|
||||
sys.exit(0)
|
||||
|
||||
# Check that the device is of a sufficient compute capability
|
||||
cc = util.get_device_cc()
|
||||
cc = device_cc()
|
||||
assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70."
|
||||
|
||||
np.random.seed(0)
|
||||
@ -71,8 +71,17 @@ C = TensorDescription(cutlass.float32, cutlass.ColumnMajor, alignment)
|
||||
element_acc = cutlass.float32
|
||||
element_epilogue = cutlass.float32
|
||||
|
||||
# Select instruction shape based on the Tensor Core instructions supported
|
||||
# by the device on which we are running
|
||||
if cc == 70:
|
||||
instruction_shape = [8, 8, 4]
|
||||
elif cc == 75:
|
||||
instruction_shape = [16, 8, 8]
|
||||
else:
|
||||
instruction_shape = [16, 8, 16]
|
||||
|
||||
math_inst = MathInstruction(
|
||||
[16, 8, 8], # Shape of the Tensor Core instruction
|
||||
instruction_shape,
|
||||
A.element, B.element, element_acc,
|
||||
cutlass.OpClass.TensorOp,
|
||||
MathOperation.multiply_add
|
||||
|
@ -102,8 +102,10 @@ Examples can be found in [$CUTLASS_PATH/examples/40_cutlass_py](examples/40_cutl
|
||||
## Test
|
||||
The test cases are listed in `$CUTLASS_PATH//tools/library/scripts/pycutlass/test`. The unit test can be run with
|
||||
```shell
|
||||
# Each of these tests are only supported on devices with compute capability of SM80. For other devices,
|
||||
# see the basic examples in $CUTLASS_PATH/examples/40_cutlass_py
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/unit && python test_sm80.py
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && run_all_example.sh
|
||||
cd $CUTLASS_PATH/tools/library/scripts/pycutlass/test/example && bash run_all_example.sh
|
||||
```
|
||||
|
||||
## build documentation
|
||||
|
@ -308,7 +308,7 @@ class ArtifactManager:
|
||||
cmd = "echo '%s'|g++ -x c++ -fpermissive -w -fPIC" % source_buffer_host
|
||||
for opt in options:
|
||||
opt = opt.decode("utf-8")
|
||||
if opt not in ['-default-device', '-std=c++11', '-arch=sm_80', '-Xcicc', '-Xllc']:
|
||||
if opt not in ['-default-device', '-std=c++11', '-Xcicc', '-Xllc'] and '-arch=sm_' not in opt:
|
||||
if '--include-path=' in opt:
|
||||
cmd += " " + opt.replace('--include-path=', '-I')
|
||||
else:
|
||||
|
@ -31,14 +31,22 @@
|
||||
#################################################################################################
|
||||
|
||||
"""
|
||||
Utility functions for interacting with device
|
||||
Utility functions for interacting with the device
|
||||
"""
|
||||
|
||||
from cuda import cudart
|
||||
|
||||
|
||||
# Raises an exception if `result` returned an error. Otherwise returns the result.
|
||||
def check_cuda_errors(result: list):
|
||||
"""
|
||||
Checks whether `result` contains a CUDA error raises the error as an exception, if so. Otherwise,
|
||||
returns the result contained in the remaining fields of `result`.
|
||||
|
||||
:param result: the results of the `cudart` method, consisting of an error code and any method results
|
||||
:type result: list
|
||||
|
||||
:return: non-error-code results from the `results` parameter
|
||||
"""
|
||||
# `result` is of the format : (cudaError_t, result...)
|
||||
err = result[0]
|
||||
if err.value:
|
||||
@ -52,8 +60,16 @@ def check_cuda_errors(result: list):
|
||||
return result[1:]
|
||||
|
||||
|
||||
# Returns the integer representation of the device compute capability
|
||||
def get_device_cc(device: int = 0):
|
||||
def device_cc(device: int = 0) -> int:
|
||||
"""
|
||||
Returns the compute capability of the device with ID `device`.
|
||||
|
||||
:param device: ID of the device to query
|
||||
:type device: int
|
||||
|
||||
:return: compute capability of the queried device (e.g., 80 for SM80)
|
||||
:rtype: int
|
||||
"""
|
||||
deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device))
|
||||
major = str(deviceProp.major)
|
||||
minor = str(deviceProp.minor)
|
@ -2,9 +2,11 @@
|
||||
from pycutlass.conv2d_operation import *
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dDgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dDgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_unity_stride_stage3(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -3,8 +3,11 @@ import pycutlass
|
||||
from pycutlass.conv2d_operation import *
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dDgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dDgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -1,8 +1,11 @@
|
||||
# test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu
|
||||
import pycutlass
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
def conv2d_few_channel_problemsizes(channels):
|
||||
problem_sizes = [
|
||||
cutlass.conv.Conv2dProblemSize(
|
||||
|
@ -1,8 +1,11 @@
|
||||
# test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu
|
||||
import pycutlass
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
def conv2d_fixed_channel_problemsizes(channels):
|
||||
problem_sizes = [
|
||||
cutlass.conv.Conv2dProblemSize(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dFpropImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dFpropImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -3,8 +3,11 @@ import pycutlass
|
||||
from pycutlass.conv2d_operation import *
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dFpropImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dFpropImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dStridedDgradImplicitGemmF16NHWCF16NHWCF32NHWCTensorOpF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Strided_Dgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dWgradImplicitGemmF16nhwcF16nhwcF16nhwcTensorOpF16SM80(unittest.TestCase):
|
||||
def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dWgradImplicitGemmF16nhwcF16nhwcF32nhwcTensorOpF32SM80(unittest.TestCase):
|
||||
def test_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -3,8 +3,11 @@ import pycutlass
|
||||
from pycutlass.conv2d_operation import *
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dWgradImplicitGemmF32nhwcF32nhwcF32nhwcSimtF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -2,8 +2,11 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class Conv2dWgradImplicitGemmTF32nhwcTF32nhwcTF32nhwcTensorOpF32SM80(unittest.TestCase):
|
||||
def test_SM80_Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -33,6 +33,7 @@
|
||||
import pycutlass
|
||||
import unittest
|
||||
from pycutlass import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import torch
|
||||
import cupy as cp
|
||||
|
||||
@ -42,13 +43,18 @@ class Test_Frontend(unittest.TestCase):
|
||||
#
|
||||
# define the cutlass operator
|
||||
#
|
||||
cc = device_cc()
|
||||
math_inst = MathInstruction(
|
||||
[1, 1, 1], cutlass.float32, cutlass.float32, cutlass.float32,
|
||||
cutlass.OpClass.Simt, MathOperation.multiply_add
|
||||
)
|
||||
|
||||
# Stages > 2 is supported only for compute capability 80 and beyond
|
||||
stages = 4 if cc >= 80 else 2
|
||||
|
||||
|
||||
tile_description = TileDescription(
|
||||
[128, 128, 8], 4, [2, 4, 1],
|
||||
[128, 128, 8], stages, [2, 4, 1],
|
||||
math_inst
|
||||
)
|
||||
|
||||
@ -69,7 +75,7 @@ class Test_Frontend(unittest.TestCase):
|
||||
math_inst.element_accumulator, cutlass.float32)
|
||||
|
||||
self.operation = GemmOperationUniversal(
|
||||
arch=80, tile_description=tile_description,
|
||||
arch=cc, tile_description=tile_description,
|
||||
A=A, B=B, C=C,
|
||||
epilogue_functor=epilogue_functor,
|
||||
swizzling_functor=cutlass.IdentitySwizzle1
|
||||
|
@ -4,7 +4,10 @@ from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
from pycutlass.test.gemm_testbed import test_all_gemm
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class GemmBF16TensorOpSm80(unittest.TestCase):
|
||||
def SM80_Device_Gemm_bf16n_bf16n_f32t_tensor_op_f32_64x128x64_32x64x64(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -4,8 +4,10 @@ from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
from pycutlass.test.gemm_testbed import test_all_gemm
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class GemmF16Sm80(unittest.TestCase):
|
||||
def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -5,8 +5,10 @@ from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
from pycutlass.test.gemm_testbed import test_all_gemm
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class GemmF32nF32nF32nTensorOpF32Sm80(unittest.TestCase):
|
||||
def test_SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -4,7 +4,10 @@ from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
from pycutlass.test.gemm_testbed import test_all_gemm
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class GemmF64TensorOpSm80(unittest.TestCase):
|
||||
def test_SM80_Device_Gemm_f64n_f64t_f64t_tensor_op_f64_32x32x16_16x16x16(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -4,8 +4,10 @@ from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
from pycutlass.test.gemm_grouped_testbed import TestbedGrouped
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class GemmGroupedSm80(unittest.TestCase):
|
||||
def test_SM80_Device_GemmGrouped_f16n_f16t_f32n_tensor_op_f32_128x128x32_64x64x32(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -5,7 +5,10 @@ from pycutlass.test import *
|
||||
import unittest
|
||||
|
||||
from pycutlass.test.gemm_testbed import test_all_gemm
|
||||
from pycutlass.utils.device import device_cc
|
||||
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
class GemmS8TensorOpF32Sm80(unittest.TestCase):
|
||||
def test_SM80_Device_Gemm_s8t_s8n_s8t_tensor_op_s32_64x64x64_32x32x64(self):
|
||||
math_inst = MathInstruction(
|
||||
|
@ -35,12 +35,14 @@
|
||||
import pycutlass
|
||||
from pycutlass import *
|
||||
from pycutlass.test import *
|
||||
from pycutlass.utils.device import device_cc
|
||||
import unittest
|
||||
|
||||
#
|
||||
# Create GEMM operation
|
||||
#
|
||||
|
||||
@unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for SM80 tests.")
|
||||
def TestGemmOperator(gemm_kind, math_inst, layout, alignment, tiling, arch, mixed=False,
|
||||
epilogue_functor=None, swizzling_functor=cutlass.IdentitySwizzle1, **kwargs):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user