torch_ext/test_reducemax.py
2024-11-16 19:26:54 +08:00

15 lines
313 B
Python

import torch
import torch_cuda_ext.core as core
# Smoke test for the CUDA extension: run `core.reducemax` on fresh random
# input 1000 times, printing `dest[0]` and `src.sum()` after each call, then
# call `core.test_cute_tensor()` once.
#
# NOTE(review): the indentation of this script was lost in the copy that was
# reviewed, so the extent of the loop body is reconstructed. Confirm that the
# two prints belong inside the loop and `test_cute_tensor` runs once after it.

n = 1000000

# Number of input elements that map to one slot of `dest`.
# Presumably the per-block reduction width of the kernel -- TODO confirm
# against the `reducemax` implementation.
GROUP_SIZE = 1024

# Ceiling division done in integer arithmetic; going through float `/` and
# `int()` can round incorrectly once the operands exceed 2**53.
# Loop-invariant, so it is computed once up front.
dest_n = (n + GROUP_SIZE - 1) // GROUP_SIZE

for _ in range(1000):
    src = torch.randn(size=(n,)).float().cuda()
    # Re-created on every iteration so each call starts from a zeroed output.
    dest = torch.zeros(size=(dest_n,)).float().cuda()
    core.reducemax(src, dest)
    print(dest[0])
    # NOTE(review): this prints the sum of the input although the kernel under
    # test is a max reduction; `src.max()` may have been intended -- confirm.
    print(src.sum())

core.test_cute_tensor()