expose stream API in python kernel call interfaces (#1287)
* expose stream API in python kernel call interfaces
* add stream to ReductionArguments; document stream arg
* add stream argument to GemmGroupedArguments
commit 8ac2edc810
parent d4be5ab5d7
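
For orientation, a hedged usage sketch of the new argument (assumes NVIDIA's cuda-python bindings and the existing cutlass.op.Gemm plan interface; the fp32 GEMM setup below is illustrative, not part of this commit):

import numpy as np
import cutlass
from cuda import cuda

# cuda-python calls return (error, value) tuples.
err, stream = cuda.cuStreamCreate(0)
assert err == cuda.CUresult.CUDA_SUCCESS

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)

A = np.random.rand(128, 128).astype(np.float32)
B = np.random.rand(128, 128).astype(np.float32)
C = np.zeros((128, 128), dtype=np.float32)
D = np.zeros((128, 128), dtype=np.float32)

# New in this commit: `stream` is forwarded all the way to the kernel launch.
# With sync=False the call returns immediately; synchronization is the
# caller's responsibility.
arguments = plan.run(A, B, C, D, stream=stream, sync=False)

(err,) = cuda.cuStreamSynchronize(stream)
assert err == cuda.CUresult.CUDA_SUCCESS
arguments.sync()  # copies the computed result back into D
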
@@ -62,6 +62,11 @@ class ArgumentBase:
             # by default, tensor_C is not bias
             self.bias = False
 
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
+
         # RMM buffers used to track tensor lifetime
         self.buffers = {}
         # Host tensor to copy the computed result back
@@ -97,6 +97,8 @@ class Conv2dArguments(ArgumentBase):
     :type split_k_mode: cutlass_library.library.SplitKMode, optional
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """
 
     def __init__(self, operation, problem_size, A, B, C, D,
@@ -448,6 +450,7 @@ class Conv2dOperation:
             arguments.host_workspace,
             arguments.device_workspace,
             arguments.launch_config,
+            arguments.stream
         )
 
         if err != cuda.CUresult.CUDA_SUCCESS:
@@ -164,6 +164,9 @@ class GemmArguments2x(ArgumentBase):
 
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """
 
     def __init__(self, operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMode.Gemm, **kwargs):
@@ -666,6 +669,9 @@ class GemmGroupedArguments:
 
     :param output_op: output operator, optional
     :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments`
+
+    :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+    :type stream: :class:`cuda.cuda.CUstream`
     """
 
     def __init__(self, operation, problem_sizes, A, B, C, D, **kwargs):
@@ -765,6 +771,11 @@ class GemmGroupedArguments:
             self.output_op = kwargs["output_op"]
         else:
             self.output_op = self.operation.epilogue_type(1.0, 0.0)
 
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
+
         # Get host problem size
         self.host_problem_size_ptr = np.array(problem_size_host, dtype=np.int32).__array_interface__["data"][0]
@@ -1542,6 +1553,7 @@ class GemmOperationBase:
             arguments.host_workspace,
             arguments.device_workspace,
             arguments.launch_config,
+            arguments.stream
         )
 
         if err != cuda.CUresult.CUDA_SUCCESS:
@@ -79,6 +79,10 @@ class ReductionArguments:
         else:
             # by default, tensor_C is not bias
             self.bias = False
+        if "stream" in kwargs.keys():
+            self.stream = kwargs["stream"]
+        else:
+            self.stream = cuda.CUstream(0)
 
         self.operation = operation
         self.ptr_workspace = workspace
@@ -386,6 +390,7 @@ class ReductionOperation:
             host_workspace,
             device_workspace,
             launch_config,
+            arguments.stream
         )
 
         if err != cuda.CUresult.CUDA_SUCCESS:
@@ -131,6 +131,7 @@ from cutlass.backend.library import TensorDescription, TileDescription
 from cutlass.op.op import OperationBase
 from cutlass.shape import Conv2DProblemSize, MatrixCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda
 
 
 class Conv2d(OperationBase):
@@ -733,7 +734,8 @@ class Conv2d(OperationBase):
             stride=(1, 1), padding=(0, 0), dilation=(1, 1),
             alpha=None, beta=None,
             split_k=("serial", 1), sync: bool = True,
-            print_module: bool = False) -> Conv2dArguments:
+            print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -760,6 +762,8 @@ class Conv2d(OperationBase):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`
 
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.Conv2dArguments
@@ -850,7 +854,8 @@ class Conv2d(OperationBase):
             A=A, B=B, C=C, D=D,
             output_op=self.operation.epilogue_type(*epilogue_args),
             split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
-            split_k_slices=split_k[1]
+            split_k_slices=split_k[1],
+            stream=stream
         )
 
         self.operation.run(arguments)
@@ -864,7 +869,8 @@ class Conv2d(OperationBase):
                 workspace=arguments.ptr_D,
                 destination=D,
                 source=C,
-                output_op=self.reduction_operation.epilogue_type(*epilogue_args)
+                output_op=self.reduction_operation.epilogue_type(*epilogue_args),
+                stream=stream
             )
             self.reduction_operation.run(reduction_arguments)
 
@@ -919,11 +925,12 @@ class Conv2dFprop(Conv2d):
     def run(
         self, input=None, weight=None, C=None, output=None, alpha=None, beta=None,
         stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-        sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+        sync: bool = True, print_module: bool = False,
+        stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
 
         A, B, D = input, weight, output
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
 
 
 class Conv2dDgrad(Conv2d):
@@ -943,11 +950,12 @@ class Conv2dDgrad(Conv2d):
 
     def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+            sync: bool = True, print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         #
         A, B, D = grad_output, weight, grad_input
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
 
 
 class Conv2dWgrad(Conv2d):
@@ -967,8 +975,9 @@ class Conv2dWgrad(Conv2d):
 
     def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None,
             stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
-            sync: bool = True, print_module: bool = False) -> Conv2dArguments:
+            sync: bool = True, print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
         #
         A, B, D = grad_output, input, grad_weight
         return super().run(
-            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module)
+            A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)
@@ -131,6 +131,7 @@ from cutlass.backend.library import TensorDescription, TileDescription
 from cutlass.op.op import OperationBase
 from cutlass.shape import GemmCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda
 
 
 class Gemm(OperationBase):
@@ -621,7 +622,8 @@ class Gemm(OperationBase):
                                 f'layout of ({ref_type}, {ref_layout}) and transpose failed.')
 
     def run(self, A=None, B=None, C=None, D=None,
-            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None) -> GemmArguments:
+            alpha=None, beta=None, sync: bool = True, print_module: bool = False, visitor_args: dict = None,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmArguments:
         """
         Runs the kernel currently specified. If it has not already been, the kernel is emitted and
         compiled. Tensors holding operands and outputs of the kernel are sourced either from the
@@ -644,6 +646,8 @@ class Gemm(OperationBase):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`
 
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmArguments
@@ -687,6 +691,7 @@ class Gemm(OperationBase):
                 'D': self._get_batch_stride(D)
             }
         }
+        kwargs['stream'] = stream
 
         if isinstance(self.epilogue_functor, EpilogueFunctorVisitor):
             output_op = self.operation.epilogue_type(visitor_args)
@@ -65,6 +65,7 @@ from cutlass.backend.library import (
 from cutlass.op.gemm import Gemm
 from cutlass.shape import GemmCoord
 from cutlass.utils import check, datatypes
+from cuda import cuda
 
 
 class GroupedGemm(Gemm):
@@ -194,7 +195,8 @@ class GroupedGemm(Gemm):
 
     def run(self, A, B, C, D,
             alpha=None, beta=None, sync: bool = True,
-            print_module: bool = False) -> GemmGroupedArguments:
+            print_module: bool = False,
+            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmGroupedArguments:
         """
         Runs the kernel currently specified.
 
@@ -217,6 +219,8 @@ class GroupedGemm(Gemm):
         :type sync: bool
         :param print_module: whether to print the emitted C++ code
         :type print_module: bool
+        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
+        :type stream: :class:`cuda.cuda.CUstream`
 
         :return: arguments passed in to the kernel
         :rtype: cutlass.backend.GemmGroupedArguments
@@ -248,7 +252,8 @@ class GroupedGemm(Gemm):
             operation=self.operation,
             problem_sizes=problem_sizes,
             A=As, B=Bs, C=Cs, D=Ds,
-            output_op=self.operation.epilogue_type(alpha, beta)
+            output_op=self.operation.epilogue_type(alpha, beta),
+            stream=stream
         )
 
         self.operation.run(arguments)
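
Since every run() signature defaults to cuda.CUstream(0), the legacy default stream, existing callers are unaffected; the point of the change is to let independent launches target different streams. A sketch of that pattern, under the same illustrative assumptions as the example above (actual overlap depends on problem sizes and host-device copies):

import numpy as np
import cutlass
from cuda import cuda

plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)

def make_operands(n):
    # Fresh operands per launch so the two GEMMs are fully independent.
    A = np.random.rand(n, n).astype(np.float32)
    B = np.random.rand(n, n).astype(np.float32)
    C = np.zeros((n, n), dtype=np.float32)
    D = np.zeros((n, n), dtype=np.float32)
    return A, B, C, D

work = []
for _ in range(2):
    err, s = cuda.cuStreamCreate(0)
    assert err == cuda.CUresult.CUDA_SUCCESS
    # Each launch is enqueued on its own stream and may overlap with the other.
    args = plan.run(*make_operands(256), stream=s, sync=False)
    work.append((s, args))

for s, args in work:
    (err,) = cuda.cuStreamSynchronize(s)
    assert err == cuda.CUresult.CUDA_SUCCESS
    args.sync()  # copy each result back to its host output tensor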