torch_ext/test_mm.py
2024-11-16 19:26:54 +08:00

88 lines
2.9 KiB
Python

# coding=utf-8
# Earlier correctness-only check (no timing), kept commented out for reference:
# import torch
# from torch_cuda_ext import core
# for i in range(1, 100):
#     a = torch.randn(size=(1024, i)).float().cuda()
#     b = torch.randn(size=(i, 1024)).float().cuda()
#     c = torch.empty(size=(1024, 1024)).float().cuda()
#     shared_c = torch.empty(size=(1024, 1024)).float().cuda()
#     core.org_mm(a, b, c)
#     real_c = torch.matmul(a, b)
#     if not torch.allclose(real_c, c):
#         print(i, torch.max(real_c - c))
#     core.org_mm_shared(a, b, shared_c)
#     if not torch.allclose(real_c, shared_c):
#         print("shared mm:", i, torch.max(real_c - shared_c))
# torch.cuda.Event
import torch
from torch_cuda_ext import core
# Benchmark and sanity-check the custom matmul kernels against torch.matmul.
#
# For every inner dimension i in [1, 1023] the script multiplies a
# (MATRIX_DIM, i) matrix by an (i, MATRIX_DIM) matrix with each kernel, prints
# the elapsed GPU time and reports any result that torch.allclose rejects.

# Outer dimension shared by both operands; every product is MATRIX_DIM x MATRIX_DIM.
MATRIX_DIM = 1024

# CUDA events used to time work submitted to the GPU.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# One (org_mm ms, org_mm_shared ms) tuple per inner dimension i.
time_cost_result = []


def _timed_ms(fn, *args):
    """Call ``fn(*args)`` between the two CUDA events.

    Returns:
        A ``(result, elapsed_ms)`` tuple, where ``elapsed_ms`` is the time
        between the ``start`` and ``end`` events in milliseconds.
    """
    start.record()
    result = fn(*args)
    end.record()
    # Both events must have completed before elapsed_time() may be queried.
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)


def _max_abs_diff(expected, actual):
    """Return the largest element-wise absolute difference of two tensors."""
    # A signed max would hide the worst error whenever it is negative.
    return (expected - actual).abs().max()


def _warm_up():
    """Launch every benchmarked kernel once, untimed.

    The first launch of a kernel can include one-off initialisation cost;
    paying it here keeps it out of the first timed iteration and therefore
    out of the averaged timings.
    """
    a = torch.randn(MATRIX_DIM, 1, dtype=torch.float32, device="cuda")
    b = torch.randn(1, MATRIX_DIM, dtype=torch.float32, device="cuda")
    out = torch.empty((MATRIX_DIM, MATRIX_DIM), dtype=torch.float32, device="cuda")
    core.org_mm(a, b, out)
    core.org_mm_shared(a, b, out)
    torch.matmul(a, b)
    core.org_mm_shared_half(a.half(), b.half(), out.half())
    torch.cuda.synchronize()


_warm_up()

for i in range(1, 1024):
    # Allocate directly on the GPU instead of building on the CPU and copying.
    a = torch.randn(MATRIX_DIM, i, dtype=torch.float32, device="cuda")
    b = torch.randn(i, MATRIX_DIM, dtype=torch.float32, device="cuda")
    c = torch.empty((MATRIX_DIM, MATRIX_DIM), dtype=torch.float32, device="cuda")
    a_half = torch.randn(MATRIX_DIM, i, dtype=torch.float16, device="cuda")
    b_half = torch.randn(i, MATRIX_DIM, dtype=torch.float16, device="cuda")
    c_half = torch.empty((MATRIX_DIM, MATRIX_DIM), dtype=torch.float16, device="cuda")
    shared_c = torch.empty((MATRIX_DIM, MATRIX_DIM), dtype=torch.float32, device="cuda")

    # core.org_mm: its output buffer c is compared with torch.matmul below.
    _, org_mm_time_cost = _timed_ms(core.org_mm, a, b, c)
    print(f"org_mm with i={i}: {org_mm_time_cost} ms")

    # Reference product from PyTorch, timed the same way.
    real_c, real_mm_time_cost = _timed_ms(torch.matmul, a, b)
    print(f"real_mm with i={i}: {real_mm_time_cost} ms")
    if not torch.allclose(real_c, c):
        print(i, _max_abs_diff(real_c, c))

    # core.org_mm_shared: same inputs, separate output buffer.
    _, shared_mm_time_cost = _timed_ms(core.org_mm_shared, a, b, shared_c)
    print(f"org_mm_shared with i={i}: {shared_mm_time_cost} ms")
    time_cost_result.append((org_mm_time_cost, shared_mm_time_cost))
    if not torch.allclose(real_c, shared_c):
        print("shared mm:", i, _max_abs_diff(real_c, shared_c))

    # core.org_mm_shared_half: float16 inputs and output.
    _, shared_half_mm_time_cost = _timed_ms(
        core.org_mm_shared_half, a_half, b_half, c_half
    )
    print(f"org_mm_shared_half with i={i}: {shared_half_mm_time_cost} ms")
    half_real_c = torch.matmul(a_half, b_half)
    # NOTE(review): torch.allclose defaults (rtol=1e-5, atol=1e-8) are tight
    # for float16 results; confirm whether looser tolerances are intended.
    if not torch.allclose(half_real_c, c_half):
        print("not equal, half real c:", i, _max_abs_diff(half_real_c, c_half))
import numpy as np
# Summarise the collected timings. Each entry of time_cost_result is an
# (org_mm ms, org_mm_shared ms) pair, so split it into one array per kernel.
org_samples = np.array([org_ms for org_ms, _ in time_cost_result])
shared_samples = np.array([shared_ms for _, shared_ms in time_cost_result])

org_mean = org_samples.mean()
share_mean = shared_samples.mean()
# Fraction of the org_mm mean time that org_mm_shared saves.
relative_saving = (org_mean - share_mean) / org_mean

print("org mean result:", org_mean)
print("shared mean result:", share_mean)
print("time cost save:", relative_saving)