# torch_ext/test_reducemax.py
# Python test script for the torch_cuda_ext kernels (original listing: 34 lines, 720 B).

import torch
import torch_cuda_ext.core as core
# --- reducemax: repeated launches on fresh random input --------------------
# NOTE(review): the pasted listing lost its indentation, so the loop body
# below (allocation, launch, and both prints) is a reconstruction; confirm
# against the original file.
n = 1_000_000
# dest gets one slot per 1024-element chunk of src, i.e. ceil(n / 1024).
# Integer arithmetic replaces float division + int(). The value does not
# depend on the loop, so it is computed once.
# Presumably 1024 matches the kernel's block size; TODO confirm.
dest_n = (n + 1024 - 1) // 1024
for _ in range(1000):
    # Generate directly on the GPU in float32 rather than building the tensor
    # on the CPU and copying it over on every iteration.
    src = torch.randn(n, device="cuda", dtype=torch.float32)
    dest = torch.zeros(dest_n, device="cuda", dtype=torch.float32)
    core.reducemax(src, dest)
    print(dest[0])
    # NOTE(review): the kernel is named "reducemax" but this reference is a
    # sum. If dest[0] is meant to hold the maximum, compare with src.max().
    print(src.sum())
# Invoke the extension's `test_cute_tensor` entry point. It takes no
# arguments here, and any return value is discarded.
core.test_cute_tensor()
# --- block_sum: one reduction per row of a 4096x4096 matrix ----------------
# Allocate directly on the GPU with an explicit dtype instead of generating a
# 64 MB tensor on the CPU, converting it, and copying it to the device.
src = torch.randn(4096, 4096, device="cuda", dtype=torch.float32)
dest = torch.zeros(4096, device="cuda", dtype=torch.float32)
core.block_sum(src, dest)
# The reference squares src before summing along dim=1, so this test treats
# block_sum's output as a per-row sum of squares.
src = src * src
real_sum = src.sum(dim=1)
diff = real_sum - dest
print(diff)
# torch abbreviates large tensors when printing, so also report the worst
# absolute error as a single number.
print(diff.abs().max())
# --- md_block_sum: reduce the last axis of a (64, 128, 4096) tensor --------
src = torch.randn(64, 128, 4096, device="cuda", dtype=torch.float32)
# Zero-initialize the output like the other tests. The original filled it
# with randn, which would corrupt the result if the kernel accumulates into
# dest (not visible from here).
dest = torch.zeros(64, 128, device="cuda", dtype=torch.float32)
core.md_block_sum(src, dest)
# Plain sum over the last axis; unlike the block_sum test, src is not squared.
real_sum = src.sum(dim=-1)
diff = real_sum - dest
print(diff)
# torch abbreviates large tensors when printing, so also report the worst
# absolute error as a single number.
print(diff.abs().max())