expose stream API in python kernel call interfaces (#1287)

* expose stream API in python kernel call interfaces
* add stream to ReductionArguments; document stream arg
* add stream argument to GemmGroupedArguments

parent d4be5ab5d7 · commit 8ac2edc810
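With this change, every kernel-invocation path in the Python interface accepts a CUDA stream, so launches can target a user-created stream instead of the legacy default stream. A minimal usage sketch, assuming the cuda-python and cupy packages and a float32 GEMM; the tensor names and sizes are illustrative, not part of this diff:

    import cupy as cp
    import cutlass
    from cuda import cuda

    # Create a non-default stream with cuda-python.
    err, stream = cuda.cuStreamCreate(0)
    assert err == cuda.CUresult.CUDA_SUCCESS

    M, N, K = 128, 128, 64
    A = cp.random.rand(M, K).astype(cp.float32)
    B = cp.random.rand(K, N).astype(cp.float32)
    C = cp.zeros((M, N), dtype=cp.float32)
    D = cp.zeros((M, N), dtype=cp.float32)

    plan = cutlass.Gemm(element=cp.float32, layout=cutlass.LayoutType.RowMajor)
    # New in this commit: run() forwards `stream` to the kernel launch.
    plan.run(A, B, C, D, sync=False, stream=stream)

    # With sync=False the call returns without waiting; synchronize the
    # stream before reading D on the host.
    cuda.cuStreamSynchronize(stream)

The hunks below thread the new keyword through the argument classes, the backend launch paths, and the high-level op wrappers.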
@@ -62,6 +62,11 @@ class ArgumentBase:
             # by default, tensor_C is not bias
             self.bias = False
 
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
+
         # RMM buffers used to track tensor lifetime
         self.buffers = {}
         # Host tensor to copy the computed result back
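A note on the hunk above: the same four-line kwargs fallback recurs in each arguments class touched below. An equivalent, more compact form would be a dict.get one-liner; shown only as a sketch of the alternative, not what the commit does:

    # Equivalent fallback to the legacy default stream (handle 0):
    self.stream = kwargs.get("stream", cuda.CUstream(0))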
@@ -97,6 +97,8 @@ class Conv2dArguments(ArgumentBase):
     :type split_k_mode: cutlass_library.library.SplitKMode, optional
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """
 
     def __init__(self, operation, problem_size, A, B, C, D,
@@ -448,6 +450,7 @@ class Conv2dOperation:
             arguments.host_workspace,
             arguments.device_workspace,
             arguments.launch_config,
+            arguments.stream
         )
 
         if err != cuda.CUresult.CUDA_SUCCESS:
@@ -164,6 +164,9 @@ class GemmArguments2x(ArgumentBase):
 
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """
 
     def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
@@ -666,6 +669,9 @@ class GemmGroupedArguments:
 
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """
 
     def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
@@ -766,6 +772,11 @@ class GemmGroupedArguments:
         else:
             self.output_op = self.operation.epilogue_type(1.0, 0.0)
 
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
+
         # Get host problem size
         self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
 
@@ -1542,6 +1553,7 @@ class GemmOperationBase:
             arguments.host_workspace,
             arguments.device_workspace,
             arguments.launch_config,
+            arguments.stream
         )
 
         if err != cuda.CUresult.CUDA_SUCCESS:
@@ -79,6 +79,10 @@ class ReductionArguments:
         else:
             # by default, tensor_C is not bias
             self.bias = False
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
 
         self.operation = operation
         self.ptr_workspace = workspace
@@ -386,6 +390,7 @@ class ReductionOperation:
             host_workspace,
             device_workspace,
             launch_config,
+            arguments.stream
         )
 
         if err != cuda.CUresult.CUDA_SUCCESS:
@@ -131,6 +131,7 @@ from cutlass.backend.library import TensorDescription, TileDescription
 from cutlass.op.op import OperationBase
 from cutlass.shape import Conv2DProblemSize, MatrixCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda
 
 
 class Conv2d(OperationBase):
@@ -733,7 +734,8 @@ class Conv2d(OperationBase):
             stride=(1, 1), padding=(0, 0), dilation=(1, 1),
             alpha=None, beta=None,
             split_k=("serial", 1), sync: bool = True,
-            print_module: bool = False) -> Conv2dArguments:
+            print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -760,6 +762,8 @@ class Conv2d(OperationBase):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`
 
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.Conv2dArguments
@@ -850,7 +854,8 @@
             A=A, B=B, C=C, D=D,
             output_op=self.operation.epilogue_type(*epilogue_args),
             split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
-            split_k_slices=split_k[1]
+            split_k_slices=split_k[1],
+            stream=stream
         )
 
         self.operation.run(arguments)
@@ -864,7 +869,8 @@
                 workspace=arguments.ptr_D,
                 destination=D,
                 source=C,
-                output_op=self.reduction_operation.epilogue_type(*epilogue_args)
+                output_op=self.reduction_operation.epilogue_type(*epilogue_args),
+                stream=stream
             )
             self.reduction_operation.run(reduction_arguments)
 
@@ -919,11 +925,12 @@ class Conv2dFprop(Conv2d):
     def run(
         self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
         stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-        sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+        sync: bool = True, print_module: bool = False,
+        stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
 
         A, B, D = input, weight, output
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
 
 
 class Conv2dDgrad(Conv2d):
@@ -943,11 +950,12 @@ class Conv2dDgrad(Conv2d):
 
     def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+            sync: bool = True, print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         #
         A, B, D = grad_output, weight, grad_input
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
 
 
 class Conv2dWgrad(Conv2d):
@@ -967,8 +975,9 @@ class Conv2dWgrad(Conv2d):
 
     def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+            sync: bool = True, print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         #
         A, B, D = grad_output, input, grad_weight
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
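In the Conv2d wrappers above, the stream is threaded through run() and, when parallel split-K is requested, into the follow-up ReductionArguments as well. A hypothetical fprop call, assuming cupy NHWC tensors x, w, bias, and y are already allocated and a stream was created as in the earlier sketch:

    err, stream = cuda.cuStreamCreate(0)
    plan = cutlass.Conv2dFprop(element=cp.float16, element_accumulator=cp.float32)
    plan.run(input=x, weight=w, C=bias, output=y,
             stride=(1, 1), padding=(1, 1),
             sync=False, stream=stream)
    cuda.cuStreamSynchronize(stream)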
@@ -131,6 +131,7 @@ from cutlass.backend.library import TensorDescription, TileDescription
 from cutlass.op.op import OperationBase
 from cutlass.shape import GemmCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda
 
 
 class Gemm(OperationBase):
@@ -621,7 +622,8 @@ class Gemm(OperationBase):
                              f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
 
     def run(self, A=None, B=None, C=None, D=None,
-            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:
+            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -644,6 +646,8 @@
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`
 
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmArguments
@@ -687,6 +691,7 @@
                 'D': self._get_batch_stride(D)
             }
         }
+        kwargs['stream'] = stream
 
         if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
            output_op = self.operation.epilogue_type(visitor_args)
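Note the plumbing here: `stream` rides along in the same kwargs dict that later constructs the GemmArguments object, so the ArgumentBase branch added in the first hunk is what ultimately picks it up.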
@@ -65,6 +65,7 @@ from cutlass.backend.library import (
 from cutlass.op.gemm import Gemm
 from cutlass.shape import GemmCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda
 
 
 class GroupedGemm(Gemm):
@@ -194,7 +195,8 @@ class GroupedGemm(Gemm):
 
     def run(self, A, B, C, D,
             alpha=None, beta=None, sync: bool = True,
-            print_module: bool = False) -> GemmGroupedArguments:
+            print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmGroupedArguments:
         """
         Runs the kernel currently specified.
 
@@ -217,6 +219,8 @@
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`
 
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmGroupedArguments
@@ -248,7 +252,8 @@
             operation=self.operation,
             problem_sizes=problem_sizes,
             A=As, B=Bs, C=Cs, D=Ds,
-            output_op=self.operation.epilogue_type(alpha, beta)
+            output_op=self.operation.epilogue_type(alpha, beta),
+            stream=stream
         )
 
         self.operation.run(arguments)
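Grouped GEMM follows the same pattern end to end. A sketch, assuming lists of per-problem tensors As, Bs, Cs, Ds and the stream created earlier:

    plan = cutlass.GroupedGemm(element=cp.float32, layout=cutlass.LayoutType.RowMajor)
    plan.run(As, Bs, Cs, Ds, sync=False, stream=stream)
    cuda.cuStreamSynchronize(stream)

Pairing sync=False with an explicit stream is what allows these launches to overlap with other device work; with the default sync=True, run() waits for completion before returning.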