torch_ext/test.py

11 lines
217 B
Python
Raw Normal View History

2024-11-16 19:26:54 +08:00
import torch
import matrix_add
# Smoke test for the `matrix_add` extension: call it on random float16 CUDA
# tensors and compare the output buffer `c` against PyTorch's own `a + b`.
SHAPE = (4, 4)

if not torch.cuda.is_available():
    raise SystemExit("matrix_add test requires a CUDA device")

# Fixed seed so a failing comparison can be reproduced exactly.
torch.manual_seed(0)
a = torch.randn(*SHAPE, device="cuda", dtype=torch.float16)
b = torch.randn(*SHAPE, device="cuda", dtype=torch.float16)
# Pre-fill the output with NaN instead of leaving it uninitialized: any element
# the kernel fails to write stays NaN and is guaranteed to fail the check below
# (assert_close treats NaN as unequal by default).
c = torch.full_like(a, float("nan"))
matrix_add.matrix_add(a, b, c)
# Unlike a bare `assert`, assert_close still runs under `python -O`, picks
# tolerances suited to float16, and reports the mismatching elements on failure.
torch.testing.assert_close(c, a + b)
print(c)