# coding=utf-8
"""Benchmark and validate the custom matmul kernels exposed by ``torch_cuda_ext``.

For every inner dimension ``i`` in ``range(1, MAX_INNER_DIM)`` the script
multiplies a ``(ROWS, i)`` matrix by an ``(i, COLS)`` matrix with

* ``core.org_mm``             (float32 inputs/output),
* ``core.org_mm_shared``      (float32 inputs/output),
* ``core.org_mm_shared_half`` (float16 inputs/output),

times each call with CUDA events, compares every result against
``torch.matmul`` and prints any mismatch. At the end it prints the mean
time of each kernel and the relative time saved by ``org_mm_shared``
over ``org_mm``.
"""
import statistics

import torch

ROWS = 1024  # rows of the left operand and of the output
COLS = 1024  # columns of the right operand and of the output
MAX_INNER_DIM = 1024  # the inner dimension i sweeps range(1, MAX_INNER_DIM)

# torch.allclose's defaults (rtol=1e-5, atol=1e-8) are tighter than the
# rounding differences between two independently accumulated matmuls: fp16
# resolution near 1.0 is ~1e-3, and fp32 outputs close to zero fail the tiny
# atol. NOTE(review): starting points - tune if the kernels' accumulation
# precision is known.
FLOAT32_TOL = {"rtol": 1e-4, "atol": 1e-4}
FLOAT16_TOL = {"rtol": 1e-2, "atol": 1e-2}


def time_cuda(fn, *args):
    """Call ``fn(*args)`` once and return ``(result, elapsed_ms)``.

    The time is measured between two CUDA timing events recorded on the
    current stream around the call. The device is synchronized before
    ``Event.elapsed_time`` is read, so both events are guaranteed to have
    completed.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn(*args)
    end.record()
    torch.cuda.synchronize()  # wait until both events have actually been recorded
    return result, start.elapsed_time(end)


def max_abs_diff(expected, actual):
    """Return ``max(|expected - actual|)`` over all elements as a Python float.

    Both tensors are upcast to float32 first, so fp16 inputs cannot overflow
    in the subtraction. The absolute value matters: a signed maximum would
    hide errors where ``actual`` is larger than ``expected``.
    """
    return (expected.float() - actual.float()).abs().max().item()


def report_mismatch(prefix, i, expected, actual, tol):
    """Print ``prefix, i, max|expected - actual|`` if the tensors are not allclose.

    ``tol`` is a dict of ``rtol``/``atol`` keyword arguments for ``torch.allclose``.
    """
    if not torch.allclose(expected, actual, **tol):
        print(prefix, i, max_abs_diff(expected, actual))


def column_means(rows):
    """Return the arithmetic mean of each column of equal-length tuples.

    Raises:
        ValueError: if ``rows`` is empty (there is nothing to average).
    """
    if not rows:
        raise ValueError("no timing rows to summarize")
    return [statistics.mean(column) for column in zip(*rows)]


def relative_saving(baseline, candidate):
    """Return the fraction of ``baseline`` saved by ``candidate``.

    The result is ``(baseline - candidate) / baseline``, or NaN when
    ``baseline`` is zero.
    """
    if baseline == 0:
        return float("nan")
    return (baseline - candidate) / baseline


def warm_up(core, device="cuda"):
    """Run every benchmarked op once, untimed, before measurements start.

    The first launch of a kernel (and the first ``torch.matmul``) can include
    one-off work. Without a warm-up that cost would be attributed to the
    ``i=1`` sample and skew the reported means. The shapes used are the same
    as the ``i=1`` iteration of the benchmark loop.
    """
    a = torch.randn(ROWS, 1, device=device, dtype=torch.float32)
    b = torch.randn(1, COLS, device=device, dtype=torch.float32)
    out = torch.empty(ROWS, COLS, device=device, dtype=torch.float32)
    core.org_mm(a, b, out)
    core.org_mm_shared(a, b, out)
    torch.matmul(a, b)

    a_half, b_half = a.half(), b.half()
    out_half = torch.empty(ROWS, COLS, device=device, dtype=torch.half)
    core.org_mm_shared_half(a_half, b_half, out_half)
    torch.matmul(a_half, b_half)
    torch.cuda.synchronize()


def main(max_inner_dim=MAX_INNER_DIM):
    """Benchmark and validate all kernels for every ``i`` in ``range(1, max_inner_dim)``.

    Returns the list of ``(org_ms, shared_ms, shared_half_ms)`` tuples, one
    per ``i``.

    Raises:
        ValueError: if ``max_inner_dim <= 1``, because no samples are collected.
    """
    # Project-local CUDA extension. It is imported here rather than at module
    # level so the helpers above can be imported without the extension built.
    from torch_cuda_ext import core

    device = "cuda"
    warm_up(core, device)

    time_cost_result = []  # one (org_ms, shared_ms, shared_half_ms) tuple per i
    for i in range(1, max_inner_dim):
        # Allocate directly on the GPU instead of creating on the CPU and copying.
        a = torch.randn(ROWS, i, device=device, dtype=torch.float32)
        b = torch.randn(i, COLS, device=device, dtype=torch.float32)
        c = torch.empty(ROWS, COLS, device=device, dtype=torch.float32)
        shared_c = torch.empty(ROWS, COLS, device=device, dtype=torch.float32)
        a_half = torch.randn(ROWS, i, device=device, dtype=torch.half)
        b_half = torch.randn(i, COLS, device=device, dtype=torch.half)
        c_half = torch.empty(ROWS, COLS, device=device, dtype=torch.half)

        _, org_ms = time_cuda(core.org_mm, a, b, c)
        print(f"org_mm with i={i}: {org_ms} ms")

        real_c, real_ms = time_cuda(torch.matmul, a, b)
        print(f"real_mm with i={i}: {real_ms} ms")
        report_mismatch("org mm:", i, real_c, c, FLOAT32_TOL)

        _, shared_ms = time_cuda(core.org_mm_shared, a, b, shared_c)
        print(f"org_mm_shared with i={i}: {shared_ms} ms")
        report_mismatch("shared mm:", i, real_c, shared_c, FLOAT32_TOL)

        _, half_ms = time_cuda(core.org_mm_shared_half, a_half, b_half, c_half)
        print(f"org_mm_shared_half with i={i}: {half_ms} ms")
        half_real_c = torch.matmul(a_half, b_half)
        report_mismatch("not equal, half real c:", i, half_real_c, c_half, FLOAT16_TOL)

        time_cost_result.append((org_ms, shared_ms, half_ms))

    org_mean, share_mean, half_mean = column_means(time_cost_result)
    print("org mean result:", org_mean)
    print("shared mean result:", share_mean)
    print("shared half mean result:", half_mean)
    print("time cost save:", relative_saving(org_mean, share_mean))
    return time_cost_result


if __name__ == "__main__":
    main()