# torch_ext/test.py
# 2024-11-16 19:26:54 +08:00
#
# 11 lines
# 217 B
# Python

import torch
import matrix_add
# Smoke test for the `matrix_add` extension: add two small float16 matrices on
# the GPU through the extension and compare the result with PyTorch's own `+`.
torch.manual_seed(0)  # fixed seed so any failure reproduces with the same inputs

a = torch.randn(4, 4, device="cuda").half()
b = torch.randn(4, 4, device="cuda").half()
# Output buffer passed as the third argument; the extension is expected to
# fill it with a + b (that is what the check below verifies).
c = torch.empty_like(a)
matrix_add.matrix_add(a, b, c)

# torch.testing.assert_close chooses tolerances from the dtype (float16:
# rtol=1e-3, atol=1e-5). torch.allclose's defaults (rtol=1e-5, atol=1e-8)
# are tighter than half precision can represent, so they can reject a correct
# kernel that rounds by one ulp. assert_close also checks shape/dtype/device,
# reports the mismatching elements, and (unlike a bare `assert`) still runs
# under `python -O`.
torch.testing.assert_close(c, a + b)
print(c)