# coding=utf-8
|
|
|
|
# import torch
|
|
# from torch_cuda_ext import core
|
|
|
|
# for i in range(1, 100):
|
|
# a = torch.randn(size=(1024, i)).float().cuda()
|
|
# b = torch.randn(size=(i, 1024)).float().cuda()
|
|
# c = torch.empty(size=(1024, 1024)).float().cuda()
|
|
# shared_c = torch.empty(size=(1024, 1024)).float().cuda()
|
|
|
|
# core.org_mm(a, b, c)
|
|
|
|
# real_c = torch.matmul(a, b)
|
|
# if not torch.allclose(real_c, c):
|
|
# print(i, torch.max(real_c - c))
|
|
|
|
# core.org_mm_shared(a, b, shared_c)
|
|
# if not torch.allclose(real_c, shared_c):
|
|
# print("shared mm:", i, torch.max(real_c - shared_c))
|
|
|
|
# torch.cuda.Event
|
|
|
|
import torch
|
|
from torch_cuda_ext import core
|
|
|
|
# Initialize CUDA events for kernel timing.
# enable_timing=True is required for Event.elapsed_time() below.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
# Collects (org_mm_ms, org_mm_shared_ms) pairs, one tuple per inner
# dimension i of the benchmark loop; consumed by the summary at the bottom.
time_cost_result = []
|
|
# Benchmark the custom CUDA matmul kernels against torch.matmul for
# (1024, i) @ (i, 1024) across inner dimensions i = 1..1023.
# Each kernel is timed with the CUDA events created above; results are
# printed per iteration and the float32 timings accumulated in
# time_cost_result for the summary at the end of the script.
for i in range(1, 1024):
    # Fresh random operands and output buffers for this inner dimension.
    a = torch.randn(size=(1024, i)).float().cuda()
    b = torch.randn(size=(i, 1024)).float().cuda()
    c = torch.empty(size=(1024, 1024)).float().cuda()

    a_half = torch.randn(size=(1024, i)).half().cuda()
    b_half = torch.randn(size=(i, 1024)).half().cuda()
    c_half = torch.empty(size=(1024, 1024)).half().cuda()

    shared_c = torch.empty(size=(1024, 1024)).float().cuda()

    # --- core.org_mm: record start, launch, record end, then sync so both
    # events are complete before reading elapsed_time.
    start.record()
    core.org_mm(a, b, c)
    end.record()
    torch.cuda.synchronize()  # wait for the event recordings to finish
    print(f"org_mm with i={i}: {start.elapsed_time(end)} ms")
    org_mm_time_cost = start.elapsed_time(end)

    # --- torch.matmul reference (cuBLAS) timing.
    start.record()
    real_c = torch.matmul(a, b)
    end.record()
    torch.cuda.synchronize()
    print(f"real_mm with i={i}: {start.elapsed_time(end)} ms")

    # Correctness check for the float32 kernel.
    if not torch.allclose(real_c, c):
        print(i, torch.max(real_c - c))

    # --- core.org_mm_shared (shared-memory variant) timing.
    start.record()
    core.org_mm_shared(a, b, shared_c)
    end.record()
    torch.cuda.synchronize()  # wait for the event recordings to finish
    print(f"org_mm_shared with i={i}: {start.elapsed_time(end)} ms")
    shared_mm_time_cost = start.elapsed_time(end)

    time_cost_result.append((org_mm_time_cost, shared_mm_time_cost))

    if not torch.allclose(real_c, shared_c):
        print("shared mm:", i, torch.max(real_c - shared_c))

    # --- core.org_mm_shared_half (half-precision variant) timing.
    start.record()
    core.org_mm_shared_half(a_half, b_half, c_half)
    end.record()
    torch.cuda.synchronize()
    print(f"org_mm_shared_half with i={i}: {start.elapsed_time(end)} ms")
    # NOTE(review): this timing is collected but never appended to
    # time_cost_result; the summary below only aggregates the two float32
    # kernels. Kept as-is so the (org, shared) pair shape is unchanged.
    shared_half_mm_time_cost = start.elapsed_time(end)

    half_real_c = torch.matmul(a_half, b_half)
    # BUG FIX: torch.allclose's defaults (rtol=1e-5, atol=1e-8) are far
    # tighter than fp16 precision (~1e-3), so the original check reported a
    # spurious mismatch on virtually every iteration. Compare in float32
    # (avoids fp16 overflow in the subtraction) with fp16-scale tolerances.
    if not torch.allclose(half_real_c.float(), c_half.float(), rtol=1e-2, atol=1e-2):
        print("not equal, half real c:", i, torch.max(half_real_c - c_half))
|
|
|
|
|
|
import numpy as np
|
|
|
|
org_mean = np.mean([x[0] for x in time_cost_result])
|
|
share_mean = np.mean([x[1] for x in time_cost_result])
|
|
print("org mean result:", org_mean)
|
|
print("shared mean result:", share_mean)
|
|
print("time cost save:", (org_mean - share_mean) / org_mean)
|