diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass/backend/evt/frontend/python_ast.py index faffce65..3f334854 100644 --- a/python/cutlass/backend/evt/frontend/python_ast.py +++ b/python/cutlass/backend/evt/frontend/python_ast.py @@ -70,6 +70,8 @@ class PythonASTFrontend(EVTFrontendBase, ast.NodeVisitor): ast.Sub: FunctionalOp.Minus, ast.Mult: FunctionalOp.Multiplies, ast.Div: FunctionalOp.Divides, + "maximum": FunctionalOp.Maximum, + "minimum": FunctionalOp.Minimum, "relu": relu.binding_type, "multiply_add": FunctionalOp.MultiplyAdd, "sum": (FunctionalOp.Plus, FunctionalOp.AtomicAdd), diff --git a/python/cutlass/epilogue/__init__.py b/python/cutlass/epilogue/__init__.py index 2b22b5f5..423decce 100644 --- a/python/cutlass/epilogue/__init__.py +++ b/python/cutlass/epilogue/__init__.py @@ -49,5 +49,7 @@ from cutlass.epilogue.evt_ops import ( multiply_add, sum, permute, - reshape + reshape, + maximum, + minimum, ) diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index a9b9b5bf..575767d0 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -59,6 +59,17 @@ def max(x, dim): elif is_torch_tensor(x): return torch.amax(x, dim) +def maximum(x, y): + if is_numpy_tensor(x): + return np.maximum(x, y) + elif is_torch_tensor(x): + return torch.maximum(x, torch.tensor(y)) + +def minimum(x, y): + if is_numpy_tensor(x): + return np.minimum(x, y) + elif is_torch_tensor(x): + return torch.minimum(x, torch.tensor(y)) ############################################################################## # Layout manipulate nodes diff --git a/test/python/cutlass/evt/evt_compute_sm80_90.py b/test/python/cutlass/evt/evt_compute_sm80_90.py index 36cee787..3f9996cf 100644 --- a/test/python/cutlass/evt/evt_compute_sm80_90.py +++ b/test/python/cutlass/evt/evt_compute_sm80_90.py @@ -95,6 +95,29 @@ class TestEVTCompute(EVTTestCaseBase): result_keys = ["D"] launcher.verify((m, n, k), input_keys, result_keys, l) + def test_func_call2(self): + """ + Test Function call + """ + + def evt_func_call2(accum, C, alpha, beta): + D = maximum(alpha * accum + beta * C, 0.0) + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.5, + "beta": 0.5, + "D": self.fake_tensor(self.element, (l, m, n)) + } + + launcher = EVTTestBed(self.element, evt_func_call2, example_inputs) + input_keys = ["C", "alpha", "beta"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + if __name__ == '__main__': unittest.main()