expose stream API in python kernel call interfaces (#1287)

* expose stream API in python kernel call interfaces

* add stream to ReductionArguments; document stream arg

* add stream argument to GemmGroupedArguments
Authored by Kun Wu on 2024-01-05 07:27:45 -06:00; committed by GitHub
parent d4be5ab5d7
commit 8ac2edc810
7 changed files with 56 additions and 12 deletions
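
With this change, callers of the Python interface can choose which CUDA stream the generated kernels are launched on. A minimal usage sketch (illustrative only, not part of this commit; it assumes the cuda-python bindings, PyTorch for device tensors, and arbitrary problem sizes):

import torch
import cutlass
from cuda import cuda

# Illustrative fp16 GEMM operands resident on the GPU.
M, N, K = 256, 256, 128
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(K, N, dtype=torch.float16, device="cuda")
C = torch.zeros(M, N, dtype=torch.float16, device="cuda")
D = torch.zeros(M, N, dtype=torch.float16, device="cuda")

plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)

# Create a non-default stream with the CUDA driver API (a context is already
# current here because the tensors above were allocated on the GPU).
err, stream = cuda.cuStreamCreate(0)

# New in this commit: the kernel launch is enqueued on `stream` rather than the
# legacy default stream cuda.CUstream(0); sync=False leaves synchronization to
# the caller.
plan.run(A, B, C, D, sync=False, stream=stream)

# Block until the work queued on the stream has finished before using D.
cuda.cuStreamSynchronize(stream)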

View File

@@ -62,6 +62,11 @@ class ArgumentBase:
             # by default, tensor_C is not bias
             self.bias = False
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
+
         # RMM buffers used to track tensor lifetime
         self.buffers = {}
         # Host tensor to copy the computed result back
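
The fallback above keeps existing callers unchanged: when no `stream` keyword is supplied, launches go to the legacy default stream (handle 0). A standalone sketch of that lookup, assuming only the cuda-python bindings (`resolve_stream` is an illustrative helper, not part of the commit):

from cuda import cuda

def resolve_stream(**kwargs):
    # Same lookup as ArgumentBase above: use the caller's stream if one was
    # passed, otherwise fall back to the legacy default stream.
    if "stream" in kwargs.keys():
        return kwargs["stream"]
    return cuda.CUstream(0)

print(resolve_stream())        # CUstream wrapping handle 0 (the default stream)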

View File

@@ -97,6 +97,8 @@ class Conv2dArguments(ArgumentBase):
     :type split_k_mode: cutlass_library.library.SplitKMode, optional
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """

     def __init__(self, operation, problem_size, A, B, C, D,
@@ -448,6 +450,7 @@ class Conv2dOperation:
             arguments.host_workspace,
             arguments.device_workspace,
             arguments.launch_config,
+            arguments.stream
         )

         if err != cuda.CUresult.CUDA_SUCCESS:

View File

@@ -164,6 +164,9 @@ class GemmArguments2x(ArgumentBase):
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """

     def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
@@ -666,6 +669,9 @@ class GemmGroupedArguments:
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """

     def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
@@ -766,6 +772,11 @@ class GemmGroupedArguments:
         else:
             self.output_op = self.operation.epilogue_type(1.0, 0.0)

+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
+
         # Get host problem size
         self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
@@ -1542,6 +1553,7 @@ class GemmOperationBase:
             arguments.host_workspace,
             arguments.device_workspace,
             arguments.launch_config,
+            arguments.stream
         )

         if err != cuda.CUresult.CUDA_SUCCESS:

View File

@@ -79,6 +79,10 @@ class ReductionArguments:
         else:
             # by default, tensor_C is not bias
             self.bias = False
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)

         self.operation = operation
         self.ptr_workspace = workspace
@@ -386,6 +390,7 @@ class ReductionOperation:
             host_workspace,
             device_workspace,
             launch_config,
+            arguments.stream
         )

         if err != cuda.CUresult.CUDA_SUCCESS:

View File

@@ -131,6 +131,7 @@ from cutlass.backend.library import TensorDescription, TileDescription
 from cutlass.op.op import OperationBase
 from cutlass.shape import Conv2DProblemSize, MatrixCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda


 class Conv2d(OperationBase):
@@ -733,7 +734,8 @@ class Conv2d(OperationBase):
             stride=(1, 1), padding=(0, 0), dilation=(1, 1),
             alpha=None, beta=None,
             split_k=("serial", 1), sync: bool = True,
-            print_module: bool = False) -> Conv2dArguments:
+            print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -760,6 +762,8 @@ class Conv2d(OperationBase):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`

         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.Conv2dArguments
@@ -850,7 +854,8 @@ class Conv2d(OperationBase):
             A=A, B=B, C=C, D=D,
             output_op=self.operation.epilogue_type(*epilogue_args),
             split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
-            split_k_slices=split_k[1]
+            split_k_slices=split_k[1],
+            stream=stream
         )

         self.operation.run(arguments)
@@ -864,7 +869,8 @@ class Conv2d(OperationBase):
                 workspace=arguments.ptr_D,
                 destination=D,
                 source=C,
-                output_op=self.reduction_operation.epilogue_type(*epilogue_args)
+                output_op=self.reduction_operation.epilogue_type(*epilogue_args),
+                stream=stream
             )

             self.reduction_operation.run(reduction_arguments)
@@ -919,11 +925,12 @@ class Conv2dFprop(Conv2d):
     def run(
         self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
         stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-        sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+        sync: bool = True, print_module: bool = False,
+        stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:

         A, B, D = input, weight, output
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)


 class Conv2dDgrad(Conv2d):
@@ -943,11 +950,12 @@ class Conv2dDgrad(Conv2d):
     def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+            sync: bool = True, print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         #
         A, B, D = grad_output, weight, grad_input
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)


 class Conv2dWgrad(Conv2d):
@@ -967,8 +975,9 @@ class Conv2dWgrad(Conv2d):
     def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+            sync: bool = True, print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         #
         A, B, D = grad_output, input, grad_weight
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
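
Taken together, the conv-side changes let a caller run any of the Conv2d plans on its own stream; with serial split-K, the follow-up reduction kernel is enqueued on the same stream. An illustrative sketch (assumes PyTorch NHWC fp16 tensors and the cuda-python bindings; the Conv2dFprop constructor arguments and all shapes are illustrative, not taken from this commit):

import torch
import cutlass
from cuda import cuda

# Illustrative NHWC fp16 operands for a 3x3 fprop convolution.
input  = torch.randn(16, 14, 14, 64,  dtype=torch.float16, device="cuda")
weight = torch.randn(128, 3, 3, 64,   dtype=torch.float16, device="cuda")
C      = torch.zeros(16, 14, 14, 128, dtype=torch.float16, device="cuda")
output = torch.zeros(16, 14, 14, 128, dtype=torch.float16, device="cuda")

plan = cutlass.Conv2dFprop(element=torch.float16, element_accumulator=torch.float32)

err, stream = cuda.cuStreamCreate(0)

# Both the conv kernel and, for serial split-K, the reduction kernel are
# launched on `stream`.
plan.run(input, weight, C, output, stride=(1, 1), padding=(1, 1),
         split_k=("serial", 2), sync=False, stream=stream)

cuda.cuStreamSynchronize(stream)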

View File

@@ -131,6 +131,7 @@ from cutlass.backend.library import TensorDescription, TileDescription
 from cutlass.op.op import OperationBase
 from cutlass.shape import GemmCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda


 class Gemm(OperationBase):
@@ -621,7 +622,8 @@ class Gemm(OperationBase):
                 f'layout of ({ref_type}, {ref_layout}) and transpose failed.')

     def run(self, A=None, B=None, C=None, D=None,
-            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:
+            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -644,6 +646,8 @@ class Gemm(OperationBase):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`

         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmArguments
@@ -687,6 +691,7 @@ class Gemm(OperationBase):
                 'D': self._get_batch_stride(D)
             }
         }
+        kwargs['stream'] = stream

         if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
             output_op = self.operation.epilogue_type(visitor_args)

View File

@@ -65,6 +65,7 @@ from cutlass.backend.library import (
 from cutlass.op.gemm import Gemm
 from cutlass.shape import GemmCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda


 class GroupedGemm(Gemm):
@@ -194,7 +195,8 @@ class GroupedGemm(Gemm):
     def run(self, A, B, C, D,
             alpha=None, beta=None, sync: bool = True,
-            print_module: bool = False) -> GemmGroupedArguments:
+            print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmGroupedArguments:
         """
         Runs the kernel currently specified.
@@ -217,6 +219,8 @@ class GroupedGemm(Gemm):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`

         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmGroupedArguments
@@ -248,7 +252,8 @@ class GroupedGemm(Gemm):
             operation=self.operation,
             problem_sizes=problem_sizes,
             A=As, B=Bs, C=Cs, D=Ds,
-            output_op=self.operation.epilogue_type(alpha, beta)
+            output_op=self.operation.epilogue_type(alpha, beta),
+            stream=stream
         )

         self.operation.run(arguments)
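
The same pattern applies to grouped GEMM: the stream passed to GroupedGemm.run is forwarded into GemmGroupedArguments and used for the grouped kernel launch. An illustrative sketch (assumes PyTorch and the cuda-python bindings; the group shapes are arbitrary):

import torch
import cutlass
from cuda import cuda

# Illustrative group of three fp16 GEMMs with different problem sizes.
shapes = [(128, 128, 64), (256, 128, 64), (64, 256, 128)]
As, Bs, Cs, Ds = [], [], [], []
for m, n, k in shapes:
    As.append(torch.randn(m, k, dtype=torch.float16, device="cuda"))
    Bs.append(torch.randn(k, n, dtype=torch.float16, device="cuda"))
    Cs.append(torch.zeros(m, n, dtype=torch.float16, device="cuda"))
    Ds.append(torch.zeros(m, n, dtype=torch.float16, device="cuda"))

plan = cutlass.op.GroupedGemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor)

err, stream = cuda.cuStreamCreate(0)

# The grouped kernel launch is enqueued on `stream`.
plan.run(As, Bs, Cs, Ds, sync=False, stream=stream)

cuda.cuStreamSynchronize(stream)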