A minor modification.
This commit is contained in:
parent 374cd36597
commit 93b10bb894
151
test_triton.py
@@ -1,151 +0,0 @@
# coding=utf-8
import triton
import triton.language as tl
import torch

configs = [
    # Config for small matrices (e.g. M,N,K < 1024)
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
        num_warps=4,  # warps per thread block (4 warps = 128 threads)
        num_stages=3,  # pipeline stages (friendlier to compute-bound ops)
    ),
    # Config for medium matrices (e.g. 1024 <= M,N,K < 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
        num_warps=8,
        num_stages=4,
    ),
    # Config for large matrices (e.g. M,N,K >= 4096)
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
        num_warps=16,
        num_stages=5,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
        num_warps=4,
        num_stages=3,
    ),
    triton.Config(
        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
        num_warps=16,
        num_stages=5,
    ),
    # Add more parameter combinations as needed...
]


@triton.autotune(configs=configs, key=["m", "n", "k"])
@triton.jit
def matmul_kernel(
    a,
    b,
    c,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    mpid = tl.program_id(0)
    npid = tl.program_id(1)
    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_pad in range(0, k, BLOCK_SIZE_K):
        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)

        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
        acc += tl.dot(a_val, b_val)
    acc = acc.to(tl.bfloat16)  # assign the cast; a bare acc.to(...) discards the result
    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
    tl.store(c_ptr, acc, mask=c_mask)


def triton_matmul(a, b):
    # Allocate the output tensor
    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)

    # Define the grid size
    grid = lambda meta: (
        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
    )

    # Launch the kernel
    matmul_kernel[grid](
        a,
        b,
        c,
        m=a.size(0),
        n=b.size(1),
        k=a.size(1),
        stride_am=a.stride(0),
        stride_an=a.stride(1),
        stride_bm=b.stride(0),
        stride_bn=b.stride(1),
        stride_cm=c.stride(0),
        stride_cn=c.stride(1),
    )
    return c


def benchmark(fn, *args, **kwargs):
    # Warm up the GPU (avoid first-run initialization skewing the timing)
    for _ in range(10):
        fn(*args, **kwargs)

    # Time with CUDA events
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    torch.cuda.synchronize()
    start_event.record()
    for _ in range(100):  # average over many runs
        fn(*args, **kwargs)
    end_event.record()
    torch.cuda.synchronize()

    return start_event.elapsed_time(end_event) / 100  # ms per call


def benchmark_triton():
    # Assumes the Triton kernel matmul_kernel, invoked through triton_matmul
    M, N, K = 4096 * 4, 4096 * 2, 4096 * 2  # matrix sizes
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

    # Measure
    time_ms = benchmark(triton_matmul, a, b)
    print(f"Triton time: {time_ms:.3f} ms")


def benchmark_torch():
    # Uses torch.matmul as the PyTorch reference
    M, N, K = 4096 * 4, 4096 * 2, 4096 * 2  # matrix sizes
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

    # Measure
    time_ms = benchmark(torch.matmul, a, b)
    print(f"PyTorch time: {time_ms:.3f} ms")


if __name__ == "__main__":
    benchmark_triton()
    benchmark_torch()
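As a quick numerical sanity check for this matmul path (a sketch of mine, not part of the commit; it assumes a CUDA device, and the bf16 tolerances are illustrative rather than principled):

a = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)
b = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16)
# fp32 accumulation is cast to bf16 on store, so compare with loose tolerances
assert torch.allclose(triton_matmul(a, b), torch.matmul(a, b), atol=2e-1, rtol=2e-2)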
@@ -1,84 +1,169 @@
-import torch
+# coding=utf-8
 import triton
 import triton.language as tl
+import torch
+from loguru import logger
 
+logger.add("triton.log", rotation="1 MB", retention="10 days", level="DEBUG")
+configs = [
+    # Config for small matrices (e.g. M,N,K < 1024)
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
+        num_warps=4,  # warps per thread block (4 warps = 128 threads)
+        num_stages=3,  # pipeline stages (friendlier to compute-bound ops)
+    ),
+    # Config for medium matrices (e.g. 1024 <= M,N,K < 4096)
+    triton.Config(
+        {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
+        num_warps=8,
+        num_stages=4,
+    ),
+    # Config for large matrices (e.g. M,N,K >= 4096)
+    triton.Config(
+        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
+        num_warps=16,
+        num_stages=5,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64},
+        num_warps=4,
+        num_stages=3,
+    ),
+    triton.Config(
+        {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128},
+        num_warps=8,
+        num_stages=3,
+    ),
+    triton.Config(
+        {
+            "BLOCK_SIZE_M": 128,  # tall tile in M
+            "BLOCK_SIZE_N": 256,  # wide tile in N (exploits parallelism along columns)
+            "BLOCK_SIZE_K": 64,  # larger K tile improves the compute/memory ratio
+        },
+        num_warps=8,  # 8 warps per thread block (256 threads)
+        num_stages=3,  # pipeline stages (balance register pressure and latency hiding)
+        # Add more parameter combinations as needed...
+    ),
+]
 
 
+@triton.autotune(configs=configs, key=["m", "n", "k"])
 @triton.jit
-def softmax_kernel(
-    output_ptr,
-    input_ptr,
-    input_row_stride,
-    output_row_stride,
-    n_cols,
-    BLOCK_SIZE: tl.constexpr,
+def matmul_kernel(
+    a,
+    b,
+    c,
+    m,
+    n,
+    k,
+    stride_am,
+    stride_an,
+    stride_bm,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
 ):
-    # Row index of the current program
-    row_idx = tl.program_id(0)
+    mpid = tl.program_id(0)
+    npid = tl.program_id(1)
+    block_m = mpid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    block_n = npid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
 
-    # Compute start pointers for the input and output rows
-    row_start_ptr = input_ptr + row_idx * input_row_stride
-    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+    for k_pad in range(0, k, BLOCK_SIZE_K):
+        k_range = tl.arange(0, BLOCK_SIZE_K) + k_pad
+        a_ptr = a + (block_m[:, None] * stride_am + k_range[None, :] * stride_an)
+        b_ptr = b + (k_range[:, None] * stride_bm + block_n[None, :] * stride_bn)
+        a_mask = (block_m[:, None] < m) & (k_range[None, :] < k)
+        b_mask = (k_range[:, None] < k) & (block_n[None, :] < n)
 
-    # Load the row into local memory
-    row_offsets = tl.arange(0, BLOCK_SIZE)
-    input_ptrs = row_start_ptr + row_offsets
-    row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf"))
-
-    # Compute softmax
-    row_minus_max = row - tl.max(row, axis=0)  # numerical stability: subtract the max
-    numerator = tl.exp(row_minus_max)
-    denominator = tl.sum(numerator, axis=0)
-    softmax_output = numerator / denominator
-
-    # Write the result back to the output
-    output_ptrs = output_row_start_ptr + row_offsets
-    tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols)
+        a_val = tl.load(a_ptr, mask=a_mask, other=0.0)
+        b_val = tl.load(b_ptr, mask=b_mask, other=0.0)
+        acc += tl.dot(a_val, b_val)
+    acc = acc.to(tl.bfloat16)  # assign the cast; a bare acc.to(...) discards the result
+    c_ptr = c + (block_m[:, None] * stride_cm + block_n[None, :] * stride_cn)
+    c_mask = (block_m[:, None] < m) & (block_n[None, :] < n)
+    tl.store(c_ptr, acc, mask=c_mask)
 
 
-def softmax(x):
-    n_rows, n_cols = x.shape
+def triton_matmul(a, b):
+    # Allocate the output tensor
+    c = torch.empty(a.size(0), b.size(1), device="cuda", dtype=torch.bfloat16)
 
-    # Allocate the output tensor
-    output = torch.empty_like(x)
-
-    # Grid and block sizes for the GPU kernel
-    BLOCK_SIZE = triton.next_power_of_2(n_cols)
-    num_warps = 4
-    if BLOCK_SIZE >= 2048:
-        num_warps = 8
-    if BLOCK_SIZE >= 4096:
-        num_warps = 16
-
-    # Launch the Triton kernel
-    softmax_kernel[(n_rows,)](
-        output,
-        x,
-        x.stride(0),
-        output.stride(0),
-        n_cols,
-        BLOCK_SIZE=BLOCK_SIZE,
-        num_warps=num_warps,
+    # Define the grid size
+    grid = lambda meta: (
+        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
+        triton.cdiv(b.size(1), meta["BLOCK_SIZE_N"]),
     )
 
-    return output
+    # Launch the kernel
+    matmul_kernel[grid](
+        a,
+        b,
+        c,
+        m=a.size(0),
+        n=b.size(1),
+        k=a.size(1),
+        stride_am=a.stride(0),
+        stride_an=a.stride(1),
+        stride_bm=b.stride(0),
+        stride_bn=b.stride(1),
+        stride_cm=c.stride(0),
+        stride_cn=c.stride(1),
+    )
+    return c
 
 
+def benchmark(fn, *args, **kwargs):
+    # Warm up the GPU (avoid first-run initialization skewing the timing)
+    for _ in range(10):
+        fn(*args, **kwargs)
+
+    # Time with CUDA events
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+
+    torch.cuda.synchronize()
+    start_event.record()
+    for _ in range(100):  # average over many runs
+        fn(*args, **kwargs)
+    end_event.record()
+    torch.cuda.synchronize()
+
+    return start_event.elapsed_time(end_event) / 100  # ms per call
+
+
+def benchmark_triton(M, N, K):
+    # Assumes the Triton kernel matmul_kernel, invoked through triton_matmul
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+
+    # Measure
+    time_ms = benchmark(triton_matmul, a, b)
+    print(f"Triton time: {time_ms:.3f} ms")
+    return time_ms
+
+
+def benchmark_torch(M, N, K):
+    # Uses torch.matmul as the PyTorch reference
+    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
+
+    # Measure
+    time_ms = benchmark(torch.matmul, a, b)
+    print(f"PyTorch time: {time_ms:.3f} ms")
+    return time_ms
+
+
-# Test softmax
 if __name__ == "__main__":
-    # Create a random matrix
-    x = torch.randn(4, 16, device="cuda")
-
-    # Compute softmax with Triton
-    output_triton = softmax(x)
-
-    # Compute softmax with PyTorch as a reference
-    output_torch = torch.softmax(x, dim=1)
-
-    # Check that the results agree
-    print("Input:")
-    print(x)
-    print("Triton Softmax:")
-    print(output_triton)
-    print("PyTorch Softmax:")
-    print(output_torch)
-    print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")
+    for M in range(1024, 4096 * 4, 1024):
+        for N in range(1024, 4096 * 2, 1024):
+            for K in range(1024, 4096 * 2, 1024):
+                triton_cost_time = benchmark_triton(M, N, K)
+                torch_cost_time = benchmark_torch(M, N, K)
+                logger.info(
+                    f"{M}, {N}, {K} triton/torch rate: {triton_cost_time / torch_cost_time}",
+                )
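For reading the log: a triton/torch rate above 1.0 means the Triton kernel was slower than torch.matmul for that (M, N, K). Since a matmul of shape (M, K) x (K, N) performs 2*M*N*K floating-point operations, the timed milliseconds also convert directly to throughput; a small helper sketch (the name tflops is mine, not part of the commit):

def tflops(M, N, K, time_ms):
    # 2*M*N*K FLOPs per matmul; milliseconds -> seconds -> TFLOP/s
    return (2 * M * N * K) / (time_ms * 1e-3) / 1e12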
57
tests/test_triton_softmax.py
Normal file
@@ -0,0 +1,57 @@
# coding=utf-8
import torch
import triton
import triton.language as tl


@triton.jit
def softmax_kernel(a_ptr, c_ptr, stride_m, stride_n, m, n, BLOCK_SIZE: tl.constexpr):
    # Start offsets of the current thread block
    midx = tl.program_id(0)
    nidx = tl.program_id(1)
    m_offset = midx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    n_offset = nidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Mask to avoid out-of-bounds accesses
    mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)

    # Compute the element indices
    index = m_offset[:, None] * stride_m + n_offset[None, :] * stride_n

    # Load the input; masked-out lanes read -inf so they do not affect the max/sum
    a = tl.load(a_ptr + index, mask=mask, other=-float("inf"))

    # Compute softmax
    max_a = tl.max(a, axis=1, keep_dims=True)  # note: Triton spells the keyword keep_dims
    exp_a = tl.exp(a - max_a)
    sum_exp_a = tl.sum(exp_a, axis=1, keep_dims=True)
    softmax_a = exp_a / sum_exp_a

    # Store the result
    tl.store(c_ptr + index, softmax_a, mask=mask)


def softmax(a, dim=0):
    """
    Perform a softmax operation on the input tensor `a` along the specified dimension `dim`.
    This function uses Triton to accelerate the computation.
    """
    m, n = a.shape
    stride_m = a.stride(0)
    stride_n = a.stride(1)

    # Allocate output tensor
    c = torch.empty_like(a)

    # Launch Triton kernel
    all_other_dim = 1
    for i in range(len(a.size())):
        if i != dim:
            all_other_dim *= a.size(i)

    # tl.arange requires a power-of-2 extent; round up and rely on the mask
    BLOCK_SIZE = triton.next_power_of_2(a.size(dim))
    grid = (all_other_dim,)
    # FIXME: The grid size should be adjusted based on the number of dimensions
    softmax_kernel[grid](a, c, stride_m, stride_n, m, n, BLOCK_SIZE=BLOCK_SIZE)

    return c
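The new test file defines softmax but never calls it; a minimal usage sketch (my addition, assuming a CUDA device and a 2-D input with the reduction over dim=1, mirroring the check the old test_triton.py performed):

x = torch.randn(4, 16, device="cuda")
out = softmax(x, dim=1)  # Triton kernel defined above
ref = torch.softmax(x, dim=1)  # PyTorch reference
print(f"Are close: {torch.allclose(out, ref, atol=1e-5)}")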