"""Smoke test for ``matrix_add.matrix_add`` on small half-precision CUDA tensors.

Builds two random 4x4 float16 operands on the GPU, passes them together with
an uninitialised output tensor to ``matrix_add.matrix_add``, and asserts that
the output tensor then matches PyTorch's own ``lhs + rhs``.
"""
import torch

import matrix_add  # NOTE(review): presumably a locally built extension module -- confirm

# Operands are sampled in the default dtype on the GPU and then cast to half.
lhs = torch.randn(4, 4, device="cuda").half()
rhs = torch.randn(4, 4, device="cuda").half()

# Output buffer with the same shape, dtype and device as ``lhs``; contents are
# uninitialised until the call below.
out = torch.empty_like(lhs)

# The assertion below expects this call to have filled ``out`` with the sum.
matrix_add.matrix_add(lhs, rhs, out)

# The reference is computed after the call, using PyTorch's own addition.
expected = lhs + rhs
assert torch.allclose(expected, out)

print(out)