add maximum support (#1833)
This commit is contained in:
parent
d65266a868
commit
f3a3bfcbf2
@ -70,6 +70,8 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor):
|
|||||||
ast.Sub: FunctionalOp.Minus,
|
ast.Sub: FunctionalOp.Minus,
|
||||||
ast.Mult: FunctionalOp.Multiplies,
|
ast.Mult: FunctionalOp.Multiplies,
|
||||||
ast.Div: FunctionalOp.Divides,
|
ast.Div: FunctionalOp.Divides,
|
||||||
|
"maximum": FunctionalOp.Maximum,
|
||||||
|
"minimum": FunctionalOp.Minimum,
|
||||||
"relu": relu.binding_type,
|
"relu": relu.binding_type,
|
||||||
"multiply_add": FunctionalOp.MultiplyAdd,
|
"multiply_add": FunctionalOp.MultiplyAdd,
|
||||||
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
"sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd),
|
||||||
|
|||||||
@ -49,5 +49,7 @@ from cutlass.epilogue.evt_ops import (
|
|||||||
multiply_add,
|
multiply_add,
|
||||||
sum,
|
sum,
|
||||||
permute,
|
permute,
|
||||||
reshape
|
reshape,
|
||||||
|
maximum,
|
||||||
|
minimum,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -59,6 +59,17 @@ def max(x, dim):
|
|||||||
elif is_torch_tensor(x):
|
elif is_torch_tensor(x):
|
||||||
return torch.amax(x, dim)
|
return torch.amax(x, dim)
|
||||||
|
|
||||||
|
def maximum(x, y):
|
||||||
|
if is_numpy_tensor(x):
|
||||||
|
return np.maximum(x, y)
|
||||||
|
elif is_torch_tensor(x):
|
||||||
|
return torch.maximum(x, torch.tensor(y))
|
||||||
|
|
||||||
|
def minimum(x, y):
|
||||||
|
if is_numpy_tensor(x):
|
||||||
|
return np.minimum(x, y)
|
||||||
|
elif is_torch_tensor(x):
|
||||||
|
return torch.minimum(x, torch.tensor(y))
|
||||||
|
|
||||||
##############################################################################
|
##############################################################################
|
||||||
# Layout manipulate nodes
|
# Layout manipulate nodes
|
||||||
|
|||||||
@ -95,6 +95,29 @@ class TestEVTCompute(EVTTestCaseBase):
|
|||||||
result_keys = ["D"]
|
result_keys = ["D"]
|
||||||
launcher.verify((m, n, k), input_keys, result_keys, l)
|
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||||
|
|
||||||
|
def test_func_call2(self):
|
||||||
|
"""
|
||||||
|
Test Function call
|
||||||
|
"""
|
||||||
|
|
||||||
|
def evt_func_call2(accum, C, alpha, beta):
|
||||||
|
D = maximum(alpha * accum + beta * C, 0.0)
|
||||||
|
return D
|
||||||
|
|
||||||
|
for m, n, k, l in self.get_problem_sizes(8):
|
||||||
|
example_inputs = {
|
||||||
|
"accum": self.fake_tensor(self.element, (l, m, n)),
|
||||||
|
"C": self.fake_tensor(self.element, (l, m, n)),
|
||||||
|
"alpha": 1.5,
|
||||||
|
"beta": 0.5,
|
||||||
|
"D": self.fake_tensor(self.element, (l, m, n))
|
||||||
|
}
|
||||||
|
|
||||||
|
launcher = EVTTestBed(self.element, evt_func_call2, example_inputs)
|
||||||
|
input_keys = ["C", "alpha", "beta"]
|
||||||
|
result_keys = ["D"]
|
||||||
|
launcher.verify((m, n, k), input_keys, result_keys, l)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user