11 lines
217 B
Python
11 lines
217 B
Python
import torch

import matrix_add

# Smoke check for the `matrix_add` extension: hand it two random
# half-precision CUDA tensors plus an output buffer, then compare the buffer
# against PyTorch's own `lhs + rhs`.
SHAPE = (4, 4)
DEVICE = "cuda"

# Sampled in float32 on the GPU first, then converted down to float16.
lhs = torch.randn(SHAPE, device=DEVICE).half()
rhs = torch.randn(SHAPE, device=DEVICE).half()

# Uninitialised buffer with the same shape, dtype and device as `lhs`.
out = torch.empty_like(lhs)

# NOTE(review): presumably writes the element-wise sum into `out` in place --
# the extension source is not visible here; the check below relies on it.
matrix_add.matrix_add(lhs, rhs, out)

# Reference result computed by PyTorch, compared with default tolerances.
expected = lhs + rhs
assert torch.allclose(expected, out)

print(out)