# torch_ext/test_reducemax.py

import torch
import torch_cuda_ext.core as core
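
# reducemax: repeatedly reduce a 1M-element vector with the custom kernel and
# compare the result against PyTorch's max.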
n = 1000000
for i in range(100):
    src = torch.randn(size=(n,)).float().cuda()
    # one output element per block of 1024 inputs
    dest_n = (n + 1024 - 1) // 1024
    dest = torch.zeros(size=(dest_n,)).float().cuda()
    core.reducemax(src, dest)
    print(dest.max())  # reduced maximum from the custom kernel
    print(src.max())   # reference maximum from PyTorch
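
# Standalone smoke test for core.test_cute_tensor.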
core.test_cute_tensor()
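
# block_sum: per-row reduction of a 4096x4096 matrix, compared against the
# row-wise sum of squares computed with PyTorch.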
src = torch.randn(size=(4096, 4096)).float().cuda()
dest = torch.zeros(size=(4096,)).float().cuda()
core.block_sum(src, dest)
src = src * src
real_sum = src.sum(dim=1)
diff = real_sum - dest
print(diff)
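
# md_block_sum: reduce the last dimension of a 3-D tensor and compare against
# torch.sum(dim=-1).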
src = torch.randn(size=(64, 128, 4096)).float().cuda()
dest = torch.randn(size=(64, 128)).float().cuda()
core.md_block_sum(src, dest)
real_sum = src.sum(dim=-1)
diff = real_sum - dest
print(diff)
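
# softmax benchmark: time core.softmax against torch.softmax over a sweep of
# (rows, cols) shapes using CUDA events, and check that the outputs match.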
for k in range(128, 4096, 128):
    for j in range(1024, 4096, 1024):
        a = torch.randn(size=(k, j)).half().cuda()
        b = torch.empty_like(a)
        num_runs = 100

        # Time the custom softmax kernel.
        times = []
        for _ in range(num_runs):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            core.softmax(a, b)
            end.record()
            torch.cuda.synchronize()  # wait for the CUDA operations to complete
            elapsed_time = start.elapsed_time(end) / 1000  # convert milliseconds to seconds
            times.append(elapsed_time)
        own_avg_time = sum(times) / num_runs
        own_std_time = (sum((t - own_avg_time) ** 2 for t in times) / num_runs) ** 0.5
        print(f"own softmax cost time: {own_avg_time}, {own_std_time}")

        # Time torch.softmax as the baseline.
        times = []
        for _ in range(num_runs):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            res = torch.softmax(a, dim=1)
            end.record()
            torch.cuda.synchronize()  # wait for the CUDA operations to complete
            elapsed_time = start.elapsed_time(end) / 1000  # convert milliseconds to seconds
            times.append(elapsed_time)
        avg_time = sum(times) / num_runs
        std_time = (sum((t - avg_time) ** 2 for t in times) / num_runs) ** 0.5
        print(f"torch softmax cost time: {avg_time}, {std_time}")

        # print("this is b", b)
        # Check correctness, then report the relative timing difference.
        diff = (res - b).abs().max()
        if diff < 1e-4:
            print("softmax is good")
            time_diff_rate = (own_avg_time - avg_time) / avg_time
            print(f"{k}, {j} matrix result {time_diff_rate}")
        else:
            print("softmax is not equal")