# cutlass/python/cutlass/backend/parser.py

################################################################################
#
# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
################################################################################
import ast
import ctypes
import inspect
import textwrap
from typing import Generic, TypeVar
from cuda import cuda, cudart
import numpy as np
from treelib import Tree
from cutlass.backend.epilogue import (
    AccumulatorOp,
    BinaryOp,
    ColumnBroadcastOp,
    ColumnReductionOp,
    RowBroadcastOp,
    RowReductionOp,
    TensorInputOp,
    TensorOutputOp,
    UnaryOp,
)
from cutlass.backend.frontend import NumpyFrontend
from cutlass.backend.utils.software import SubstituteTemplate
import cutlass.backend as backend

################################################################################
# Type annotation for input arguments
################################################################################

Ttype = TypeVar("Ttype")
Dtype = TypeVar("Dtype")


class NDArray(np.ndarray, Generic[Ttype, Dtype]):
    pass
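
# NDArray is exported so user epilogues can spell out generic tensor annotations
# if they wish (an intended-usage note, not enforced here); the parser itself
# only keys off the string-constant annotations 'tensor', 'row', 'column', and
# 'scalar' (see EpilogueAST.visit_FunctionDef below).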

################################################################################
# Operations
################################################################################

operators = {
    ast.Add: "Add",
    ast.Div: "Div",
    ast.Eq: "Equal",
    ast.Mult: "Mult",
}
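
# For example, an expression such as `accum + x` yields an ast.Add node, which
# BinOpNode below maps to the name "Add" and resolves to the backend functor
# backend.VectorAdd; binary nodes with one scalar operand are later folded into
# UnaryNode by pass_binary_2_unary.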

################################################################################
# AST Node abstractions
################################################################################

class UnaryNode:
    cnt = 0

    # Concept: this is created by the BinOp Node in python ast
    def __init__(
        self,
        element_accumulator,
        element_compute,
        elements_per_access,
        node,
        args,
    ) -> None:
        if isinstance(node, BinOpNode):
            self.op = node.op
        elif isinstance(node, ast.Call):
            if isinstance(node.func, ast.Name):
                self.op = node.func.id
            elif isinstance(node.func, ast.Attribute):
                self.op = node.func.value.id
            else:
                raise TypeError
        else:
            raise TypeError
        self.tag = "Unary" + self.op + str(UnaryNode.cnt)
        self.id = self.op + str(UnaryNode.cnt)
        self.args = args
        UnaryNode.cnt += 1
        self.type = "tensor"
        self.epilogue_op = getattr(backend, self.op)(element_compute)
        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = UnaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        epilogue_ops = []
        for arg in self.args:
            try:
                epilogue_ops.append(kwargs[arg])
            except KeyError:
                epilogue_ops.append(arg)  # direct arguments like constant
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(*epilogue_ops),
            *visitor_args,
        )


class BinOpNode:
    cnt = 0

    # Concept: this is created by the BinOp Node in python ast
    def __init__(
        self,
        element_accumulator,
        element_compute,
        elements_per_access,
        node,
    ) -> None:
        self.op = operators[type(node.op)]
        self.tag = "Binary" + self.op + str(BinOpNode.cnt)
        self.id = self.op + str(BinOpNode.cnt)
        self.args = None
        BinOpNode.cnt += 1
        self.type = "tensor"
        self.epilogue_op = getattr(backend, "Vector" + self.op)(element_compute)
        # data types
        self.element_accumulator = element_accumulator
        self.element_compute = element_compute
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = BinaryOp(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            *visitors,
            self.epilogue_op,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            self.epilogue_op.argument_type(self.args),
            *visitor_args,
        )


class NameNode:
    # Concept: this is created by the Name Node in python ast
    def __init__(self, node) -> None:
        try:
            self.id = node.id
        except AttributeError:
            # constructed from an ast.Assign node: use the assignment target's name
            self.id = node.targets[0].id
        self.tag = self.id


class ScalarInputNode(NameNode):
    # Concept: scalar
    def __init__(self, node) -> None:
        super().__init__(node)
        self.tag = "Scalar:" + self.tag
        self.type = "scalar"


class AccumulatorNode(NameNode):
    # Concept: VisitorOpAccumulator
    def __init__(
        self,
        element_accumulator,
        elements_per_access,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = "Accum:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access

    def get_epilogue_node(self, visitors):
        self.epilogue_node = AccumulatorOp(
            self.element_accumulator,
            self.elements_per_access,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type()


class TensorInputNode(NameNode):
    # Concept: VisitorOpTensorInput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = "TensorInput:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, *args):
        self.epilogue_node = TensorInputOp(self.element_accumulator)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
            kwargs["problem_size"][0] * kwargs["problem_size"][1],
        )
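
# Runtime argument convention used by the tensor nodes in this module: kwargs
# carries the device pointer under "<name>_ptr" and the problem size as a
# (rows, columns) pair, so problem_size[1] serves as the leading dimension and
# problem_size[0] * problem_size[1] as the per-batch stride.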

class RowBroadcastNode(NameNode):
    # Concept: VisitorOpRowBroadcast
    def __init__(
        self,
        element_accumulator,
        element_fragment,
        node,
    ) -> None:
        super().__init__(node)
        #
        self.tag = "RowBroadcast:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        self.epilogue_node = RowBroadcastOp(
            self.element_accumulator,
            self.element_fragment,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
        )


class ColumnBroadcastNode(NameNode):
    # Concept: VisitorOpColumnBroadcast
    def __init__(
        self,
        element_accumulator,
        element_fragment,
        node,
    ) -> None:
        super().__init__(node)
        self.tag = "ColumnBroadcast:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_fragment = element_fragment

    def get_epilogue_node(self, *args):
        self.epilogue_node = ColumnBroadcastOp(
            self.element_accumulator,
            self.element_fragment,
        )

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][0],
        )


class TensorOutputNode(NameNode):
    # Concept: VisitorOpTensorOutput
    def __init__(self, element_accumulator, node) -> None:
        super().__init__(node)
        self.tag = "TensorOutput:" + self.tag
        self.type = "tensor"
        self.element_accumulator = element_accumulator

    def get_epilogue_node(self, visitors):
        self.epilogue_node = TensorOutputOp(self.element_accumulator, *visitors)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            kwargs["problem_size"][1],
            *visitor_args,
            kwargs["problem_size"][0] * kwargs["problem_size"][1],
        )


class RowReductionNode:
    # Concept: RowReductionOp
    def __init__(
        self,
        element_accumulator,
        element_reduction,
        element_reduction_accumulator,
        id,
        factor,
    ) -> None:
        #
        self.id = id
        self.tag = "RowReduction:" + self.id
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.factor = factor

    def get_epilogue_node(self, visitors):
        self.epilogue_node = RowReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        return problem_size[0] * ((problem_size[1] + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )


class ColumnReductionNode:
    # Concept: ColumnReductionOp
    def __init__(
        self,
        element_accumulator,
        element_reduction,
        element_reduction_accumulator,
        id,
        factor,
    ) -> None:
        #
        self.id = id
        self.tag = "ColumnReduction:" + self.id
        self.type = "tensor"
        self.element_accumulator = element_accumulator
        self.element_reduction = element_reduction
        self.element_reduction_accumulator = element_reduction_accumulator
        self.factor = factor

    def get_epilogue_node(self, visitors):
        self.epilogue_node = ColumnReductionOp(
            self.element_accumulator,
            self.element_reduction,
            self.element_reduction_accumulator,
            *visitors,
        )

    def get_batch_stride(self, problem_size):
        return problem_size[1] * ((problem_size[0] + self.factor - 1) // self.factor)

    def get_argument(self, visitor_args, kwargs):
        self.argument = self.epilogue_node.argument_type(
            kwargs[self.id + "_ptr"],
            *visitor_args,
            self.get_batch_stride(kwargs["problem_size"]),
        )
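
# Batch-stride arithmetic (illustrative numbers only): with problem_size = (512, 384),
# a row reduction created with factor = threadblock_shape[1] = 128 uses a batch
# stride of 512 * ceil(384 / 128) = 1536 elements, while a column reduction created
# with factor = threadblock_shape[0] = 64 uses 384 * ceil(512 / 64) = 3072 elements.
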
################################################################################
# Epilogue parser function
################################################################################
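
# A minimal sketch (hypothetical user code, not part of this module) of the kind
# of epilogue __call__ that EpilogueAST is designed to parse: arguments are
# annotated with the string constants 'tensor', 'row', 'column', or 'scalar';
# the name 'accum' denotes the accumulator; names that are returned become
# tensor outputs; and reduction_op(...) marks a row or column reduction.
#
#     def __call__(self, accum: 'tensor', c: 'tensor', alpha: 'scalar',
#                  beta: 'scalar', bias: 'row'):
#         d = alpha * accum + beta * c + bias
#         d_col = reduction_op(d, 'column', 'Add')
#         return d, d_col
#
# The scalar * tensor products are lowered to unary nodes by pass_binary_2_unary
# below; which functor names are available is determined by cutlass.backend.
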
class EpilogueAST(ast.NodeVisitor):
    def __init__(
        self,
        epilogue,
        tile_description,
        element_accumulator,
        elements_per_access,
        element_compute,
        element_output,
    ) -> None:
        #
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output

        self.epilogue = epilogue
        self.source = textwrap.dedent(inspect.getsource(epilogue.__call__))
        self.ast_tree = ast.parse(self.source)
        self.epilogue_tree = Tree()
        # print(ast.dump(self.ast_tree, indent=4)) # For Debug purpose

        # input arguments
        self.input_args = {}
        # return nodes
        self.returns = []
        # reduction source nodes
        self.reduction_source = {}
        # stack used to keep the parent node id
        self.stack = []

        # visit the AST
        self.visit(self.ast_tree)

    # visit the name node
    def visit_Name(self, node):
        # append the return ids into self.returns
        if self.stack[-1] == "return":
            self.returns.append(node.id)
        else:
            # accum is produced from accumulator node
            if node.id == "accum":
                name_node = AccumulatorNode(
                    self.element_accumulator,
                    self.elements_per_access,
                    node,
                )
            else:
                # for input nodes
                if node.id in self.input_args.keys():
                    type = self.input_args[node.id][0]
                    if type == "tensor":
                        name_node = TensorInputNode(
                            self.element_accumulator,
                            node,
                        )
                    elif type == "row":
                        name_node = RowBroadcastNode(
                            self.element_accumulator,
                            self.element_compute,
                            node,
                        )
                    elif type == "column":
                        name_node = ColumnBroadcastNode(
                            self.element_accumulator,
                            self.element_compute,
                            node,
                        )
                    elif type == "scalar":
                        name_node = ScalarInputNode(node)
                    else:
                        raise ValueError(type)
                # for output nodes
                else:
                    name_node = TensorOutputNode(
                        self.element_accumulator,
                        node,
                    )
            self.epilogue_tree.create_node(
                name_node.tag,
                name_node.id,
                data=name_node,
                parent=self.stack[-1],
            )

    def visit_Assign(self, node):
        pre_assign_node = self.epilogue_tree.get_node(node.targets[0].id)
        if pre_assign_node is None:
            # The assign is to a root node
            # skip the reduction nodes
            if isinstance(node.value, ast.Call):
                if isinstance(node.value.func, ast.Name):
                    func_type = node.value.func.id
                elif isinstance(node.value.func, ast.Attribute):
                    func_type = node.value.func.value.id
                else:
                    raise TypeError
                if func_type == "reduction_op":
                    self.reduction_source[node.value.args[0].id] = [
                        node.value.args[1].value,
                        node.value.args[2].value,
                        node.targets[0].id,
                    ]
                    return
            name_node = TensorOutputNode(self.element_accumulator, node)
            self.epilogue_tree.create_node(
                name_node.tag,
                name_node.id,
                data=name_node,
            )
            self.stack.append(name_node.id)
        else:
            if (
                node.targets[0].id in self.returns
                or node.targets[0].id in self.reduction_source.keys()
            ):
                self.stack.append(node.targets[0].id)
            else:
                self.stack.append(
                    pre_assign_node.predecessor(self.epilogue_tree.identifier)
                )
                self.epilogue_tree.remove_node(node.targets[0].id)

        # get child tag
        self.visit(node.value)
        self.stack.pop()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name):
            func_type = node.func.id
        elif isinstance(node.func, ast.Attribute):
            func_type = node.func.value.id
        else:
            raise TypeError
        if func_type == "reduction_op":
            self.visit(node.args[0])
        else:
            arg_list = []
            for idx, arg in enumerate(node.args):
                if idx == 0:
                    continue
                if isinstance(arg, ast.Constant):
                    arg_list.append(arg.value)
                elif isinstance(arg, ast.Name):
                    arg_list.append(arg.id)
                else:
                    raise TypeError
            unary_node = UnaryNode(
                self.element_accumulator,
                self.element_compute,
                self.elements_per_access,
                node,
                arg_list,
            )
            self.epilogue_tree.create_node(
                unary_node.tag,
                unary_node.id,
                parent=self.stack[-1],
                data=unary_node,
            )
            self.stack.append(unary_node.id)
            self.visit(node.args[0])
            self.stack.pop()

    def visit_BinOp(self, node):
        binop = BinOpNode(
            self.element_accumulator,
            self.element_compute,
            self.elements_per_access,
            node,
        )
        self.epilogue_tree.create_node(
            binop.tag,
            binop.id,
            data=binop,
            parent=self.stack[-1],
        )
        self.stack.append(binop.id)
        self.visit(node.left)
        self.visit(node.right)
        self.stack.pop()

    def visit_Return(self, node):
        self.stack.append("return")
        self.visit(node.value)
        self.stack.pop()

    # # A function definition
    def visit_FunctionDef(self, node: ast.FunctionDef):
        # visit args
        for arg in node.args.args:
            if arg.arg == "self":
                continue
            if isinstance(arg.annotation, ast.Constant):
                self.input_args[arg.arg] = [
                    arg.annotation.value,
                ]
        # visit the assign in the reverse order
        for idx in range(len(node.body)):
            self.visit(node.body[-1 - idx])
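
    # Note: visit_FunctionDef walks the statements last-to-first so that the
    # names listed in the return statement are known before the assignments
    # that produce them are parsed.  For a hypothetical body
    #
    #     t = accum + x
    #     d = t * beta
    #     return d
    #
    # 'd' is registered as a return first, 'd = t * beta' then builds the output
    # subtree with a placeholder node for 't', and finally 't = accum + x'
    # splices its expression into the tree in place of that placeholder
    # (see visit_Assign above).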

    #
    # Tree optimization pass
    #
    # pass 1: lower Binary to Unary
    def pass_binary_2_unary(self, tree, nid):
        node = tree.get_node(nid)
        if isinstance(node.data, BinOpNode):
            lhs_node = tree.get_node(node.successors(tree.identifier)[0])
            left_type = lhs_node.data.type
            rhs_node = tree.get_node(node.successors(tree.identifier)[1])
            right_type = rhs_node.data.type
            if left_type == "scalar" and right_type == "tensor":
                node.data = UnaryNode(
                    self.element_accumulator,
                    self.element_compute,
                    self.elements_per_access,
                    node.data,
                    [
                        lhs_node.data.id,
                    ],
                )
                node.tag = node.data.tag
                tree.remove_node(lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)
            elif left_type == "tensor" and right_type == "scalar":
                node.data = UnaryNode(
                    self.element_accumulator,
                    self.element_compute,
                    self.elements_per_access,
                    node.data,
                    [
                        rhs_node.data.id,
                    ],
                )
                node.tag = node.data.tag
                tree.remove_node(rhs_node.data.id)
                self.pass_binary_2_unary(tree, lhs_node.data.id)
            else:
                self.pass_binary_2_unary(tree, lhs_node.data.id)
                self.pass_binary_2_unary(tree, rhs_node.data.id)
        else:
            for child in node.successors(tree.identifier):
                self.pass_binary_2_unary(tree, child)
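
    # For instance (illustrative): a Binary "Mult" node whose children are the
    # scalar 'beta' and a tensor subtree is rewritten in place into a Unary
    # "Mult" node parameterized by 'beta'; the scalar child is removed and the
    # pass recurses into the remaining tensor child.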

    # pass 2: inject reduction nodes
    def pass_inject_reduction(self, tree, nid):
        node = tree.get_node(nid)
        if isinstance(node.data, TensorOutputNode):
            if node.data.id in self.reduction_source.keys():
                direction = self.reduction_source[node.data.id][0]
                target = self.reduction_source[node.data.id][-1]
                if direction == "row":
                    reduction_node = RowReductionNode(
                        self.element_accumulator,
                        self.element_output,
                        self.element_accumulator,
                        target,
                        self.tile_description.threadblock_shape[1],
                    )
                elif direction == "column":
                    reduction_node = ColumnReductionNode(
                        self.element_accumulator,
                        self.element_output,
                        self.element_accumulator,
                        target,
                        self.tile_description.threadblock_shape[0],
                    )
                else:
                    raise ValueError(direction)
                child_nid = node.successors(tree.identifier)[0]
                # if this output node is injected only for reduction
                if node.data.id not in self.returns:
                    # get reduction config from disc
                    node.data = reduction_node
                    node.tag = reduction_node.tag
                    self.pass_inject_reduction(tree, child_nid)
                # if this output node is also a tensor output, inject reduction as its children
                else:
                    # get child node
                    tree.create_node(
                        reduction_node.tag,
                        reduction_node.id,
                        data=reduction_node,
                        parent=node.data.id,
                    )
                    tree.move_node(
                        child_nid,
                        reduction_node.id,
                    )
                    child = tree.get_node(child_nid)
                    for grand_child in child.successors(tree.identifier):
                        self.pass_inject_reduction(tree, grand_child)
            else:
                for child in node.successors(tree.identifier):
                    self.pass_inject_reduction(tree, child)
        else:
            for child in node.successors(tree.identifier):
                self.pass_inject_reduction(tree, child)

    def pass_inject_epilogue_op(self, tree, nid):
        node = tree.get_node(nid)
        visitors = []
        for child in node.successors(tree.identifier):
            visitors.append(self.pass_inject_epilogue_op(tree, child))
        node.data.get_epilogue_node(visitors)
        return node.data.epilogue_node

    def get_arguments(self, tree, nid, kwargs):
        node = tree.get_node(nid)
        visitor_args = []
        for child in node.successors(tree.identifier):
            visitor_args.append(self.get_arguments(tree, child, kwargs))
        node.data.get_argument(visitor_args, kwargs)
        return node.data.argument


class EpilogueVisitTree:
    KernelTemplate = """
${visitor}
using ${operation_name}_EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorGeneric<${visitor_name}>;
"""

    def __init__(
        self,
        elementwise_functor,
        tile_description,
        element_accumulator,
        elements_per_access,
        element_compute,
        element_output,
    ) -> None:
        #
        # data types
        self.tile_description = tile_description
        self.element_accumulator = element_accumulator
        self.elements_per_access = elements_per_access
        self.element_compute = element_compute
        self.element_output = element_output

        self.elementwise_functor = elementwise_functor
        pass

    def initialize(self):
        function = EpilogueAST(
            self,
            self.tile_description,
            self.element_accumulator,
            self.elements_per_access,
            self.element_compute,
            self.element_output,
        )
        #
        tree = function.epilogue_tree
        self.tree = tree

        function.pass_binary_2_unary(self.tree, self.tree.root)
        function.pass_inject_reduction(self.tree, self.tree.root)
        function.pass_inject_epilogue_op(self.tree, self.tree.root)

        visitor = self.tree.get_node(self.tree.root).data.epilogue_node
        self.visitor = visitor

        class _Argument(ctypes.Structure):
            _fields_ = [
                (
                    "visitor_arg",
                    visitor.argument_type,
                )
            ]

            def __init__(self, **kwargs) -> None:
                # process input args
                _kwargs = {}
                for input_key in function.input_args.keys():
                    if input_key == "accum":
                        continue
                    if function.input_args[input_key][0] == "scalar":
                        continue
                    # tensor input
                    else:
                        setattr(
                            self,
                            "buffer_tensor_" + input_key,
                            NumpyFrontend.argument(
                                kwargs[input_key],
                                False,
                            ),
                        )
                        setattr(
                            self,
                            input_key + "_ptr",
                            int(
                                getattr(
                                    self,
                                    "buffer_tensor_" + input_key,
                                ).ptr
                            ),
                        )
                        _kwargs[input_key + "_ptr"] = getattr(
                            self,
                            input_key + "_ptr",
                        )
                # process the return args
                for ret in function.returns:
                    setattr(
                        self,
                        "buffer_tensor_" + ret,
                        NumpyFrontend.argument(kwargs[ret], True),
                    )
                    setattr(
                        self,
                        ret + "_ptr",
                        int(
                            getattr(
                                self,
                                "buffer_tensor_" + ret,
                            ).ptr
                        ),
                    )
                    _kwargs[ret + "_ptr"] = getattr(self, ret + "_ptr")
                    setattr(
                        self,
                        "host_tensor_" + ret,
                        kwargs[ret],
                    )
                _kwargs.update(kwargs)

                function.get_arguments(tree, tree.root, _kwargs)
                self.visitor_arg = tree.get_node(tree.root).data.argument

            def sync(self, stream_sync=True):
                if stream_sync:
                    (err,) = cudart.cudaDeviceSynchronize()
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))

                for ret in function.returns:
                    (err,) = cuda.cuMemcpyDtoH(
                        getattr(
                            self,
                            "host_tensor_" + ret,
                        ),
                        cuda.CUdeviceptr(getattr(self, ret + "_ptr")),
                        getattr(
                            self,
                            "host_tensor_" + ret,
                        ).size
                        * getattr(
                            self,
                            "host_tensor_" + ret,
                        ).itemsize,
                    )
                    if err != cuda.CUresult.CUDA_SUCCESS:
                        raise RuntimeError("CUDA Error %s" % str(err))
                pass

        self.epilogue_type = _Argument

    def emit(self, operation):
        values = {
            "visitor": self.visitor.emit(operation),
            "operation_name": operation.procedural_name(),
            "visitor_name": self.visitor.instance_name,
        }
        return SubstituteTemplate(self.KernelTemplate, values)
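

# Typical usage (a minimal sketch with hypothetical names; the exact wiring into
# a GEMM operation lives elsewhere in cutlass.backend): subclass
# EpilogueVisitTree, write the epilogue as __call__ with annotated arguments,
# then call initialize() to build the visitor tree and the ctypes argument type.
#
#     class LinearCombination(EpilogueVisitTree):
#         def __call__(self, accum: 'tensor', c: 'tensor',
#                      alpha: 'scalar', beta: 'scalar'):
#             d = alpha * accum + beta * c
#             return d
#
#     epilogue_functor = LinearCombination(
#         elementwise_functor, tile_description, element_accumulator,
#         elements_per_access, element_compute, element_output)
#     epilogue_functor.initialize()
#
# After initialize(), epilogue_functor.epilogue_type(...) packs the runtime
# tensors and scalars (keyed by the __call__ argument names, plus problem_size)
# into the visitor argument structure, and emit(operation) renders the C++
# EpilogueVisitor type for code generation.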