test triton, seems like very well.
This commit is contained in:
parent
58093d7a71
commit
c77f9602ea
36
csrc/quantize.cu
Normal file
36
csrc/quantize.cu
Normal file
@ -0,0 +1,36 @@
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
__global__ void quantize(const half *src, __nv_fp8_storage_t *dest, int x_len, int y_len)
|
||||
{
|
||||
int x_start = threadIdx.x * blockDim.x;
|
||||
int y_start = threadIdx.y * blockDim.y;
|
||||
__shared__ half max_value;
|
||||
|
||||
max_value = __float2half(-10000.0f);
|
||||
for (int i = 0; i < blockDim.x; i++)
|
||||
{
|
||||
for (int j = 0; j < blockDim.x; j++)
|
||||
{
|
||||
if (x_start + i < x_len && y_start + j < y_len)
|
||||
{
|
||||
|
||||
int real_offset = (y_start + j) * x_len + x_start + i;
|
||||
max_value = __hmax(src[real_offset], max_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < blockDim.x; i++)
|
||||
{
|
||||
for (int j = 0; j < blockDim.y; j++)
|
||||
{
|
||||
if (x_start + i < x_len && y_start + j < y_len)
|
||||
{
|
||||
|
||||
int real_offset = (y_start + j) * x_len + x_start + i;
|
||||
half tmp = __hdiv(src[real_offset], max_value);
|
||||
dest[real_offset] = __nv_cvt_halfraw_to_fp8(__nv_half_raw(tmp), __NV_SATFINITE, __NV_E5M2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
1
setup.py
1
setup.py
@ -11,6 +11,7 @@ files = [
|
||||
"csrc/core_bind.cpp",
|
||||
"csrc/max.cu",
|
||||
"csrc/md.cu",
|
||||
"csrc/quantize.cu",
|
||||
]
|
||||
extension = CUDAExtension(
|
||||
name="torch_cuda_ext.core",
|
||||
|
||||
468
tests/test_mma.py
Normal file
468
tests/test_mma.py
Normal file
@ -0,0 +1,468 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
DEVICE = "cuda"
|
||||
|
||||
|
||||
def is_cuda():
|
||||
return triton.runtime.driver.active.get_current_target().backend == "cuda"
|
||||
|
||||
|
||||
def is_hip_cdna2():
|
||||
target = triton.runtime.driver.active.get_current_target()
|
||||
return target.backend == "hip" and target.arch == "gfx90a"
|
||||
|
||||
|
||||
def get_cuda_autotune_config():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
# Good config for fp8 inputs.
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 8,
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_hip_autotune_config():
|
||||
return [
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 16,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"waves_per_eu": 2,
|
||||
},
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 16,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"waves_per_eu": 2,
|
||||
},
|
||||
num_warps=8,
|
||||
num_stages=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"waves_per_eu": 2,
|
||||
},
|
||||
num_warps=8,
|
||||
num_stages=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"waves_per_eu": 3,
|
||||
},
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 32,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"waves_per_eu": 8,
|
||||
},
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_autotune_config():
|
||||
if is_cuda():
|
||||
return get_cuda_autotune_config()
|
||||
else:
|
||||
return get_hip_autotune_config()
|
||||
|
||||
|
||||
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
|
||||
# - A list of `triton.Config` objects that define different configurations of
|
||||
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
|
||||
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
|
||||
# provided configs
|
||||
@triton.autotune(
|
||||
configs=get_autotune_config(),
|
||||
key=["M", "N", "K"],
|
||||
)
|
||||
@triton.jit
|
||||
def matmul_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
# Matrix dimensions
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# The stride variables represent how much to increase the ptr by when moving by 1
|
||||
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
|
||||
# by to get the element one row down (A has M rows).
|
||||
stride_am,
|
||||
stride_ak, #
|
||||
stride_bk,
|
||||
stride_bn, #
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr, #
|
||||
GROUP_SIZE_M: tl.constexpr, #
|
||||
ACTIVATION: tl.constexpr, #
|
||||
):
|
||||
"""Kernel for computing the matmul C = A x B.
|
||||
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
|
||||
"""
|
||||
# -----------------------------------------------------------
|
||||
# Map program ids `pid` to the block of C it should compute.
|
||||
# This is done in a grouped ordering to promote L2 data reuse.
|
||||
# See above `L2 Cache Optimizations` section for details.
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
|
||||
pid_n = (pid % num_pid_in_group) // group_size_m
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Create pointers for the first blocks of A and B.
|
||||
# We will advance this pointer as we move in the K direction
|
||||
# and accumulate
|
||||
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
|
||||
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
|
||||
# See above `Pointer Arithmetic` section for details
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Iterate to compute a block of the C matrix.
|
||||
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
|
||||
# of fp32 values for higher accuracy.
|
||||
# `accumulator` will be converted back to fp16 after the loop.
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
# Load the next block of A and B, generate a mask by checking the K dimension.
|
||||
# If it is out of bounds, set it to 0.
|
||||
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
|
||||
# We accumulate along the K dimension.
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
# Advance the ptrs to the next K block.
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
# You can fuse arbitrary activation functions here
|
||||
# while the accumulator is still in FP32!
|
||||
if ACTIVATION == "leaky_relu":
|
||||
accumulator = leaky_relu(accumulator)
|
||||
c = accumulator.to(tl.float16)
|
||||
|
||||
# -----------------------------------------------------------
|
||||
# Write back the block of the output matrix C with masks.
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
|
||||
@triton.jit
|
||||
def leaky_relu(x):
|
||||
return tl.where(x >= 0, x, 0.01 * x)
|
||||
|
||||
|
||||
def matmul(a, b, activation=""):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
assert a.is_contiguous(), "Matrix A must be contiguous"
|
||||
M, K = a.shape
|
||||
K, N = b.shape
|
||||
# Allocates output.
|
||||
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
|
||||
# 1D launch kernel where each block gets its own program.
|
||||
grid = lambda META: (
|
||||
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||
)
|
||||
matmul_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c, #
|
||||
M,
|
||||
N,
|
||||
K, #
|
||||
a.stride(0),
|
||||
a.stride(1), #
|
||||
b.stride(0),
|
||||
b.stride(1), #
|
||||
c.stride(0),
|
||||
c.stride(1), #
|
||||
ACTIVATION=activation, #
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
|
||||
triton_output = matmul(a, b)
|
||||
torch_output = torch.matmul(a, b)
|
||||
print(f"triton_output_with_fp16_inputs={triton_output}")
|
||||
print(f"torch_output_with_fp16_inputs={torch_output}")
|
||||
# Bigger tolerance for AMD CDNA2 devices.
|
||||
# CDNA2 devices use reduced precision fp16 and bf16 and flush input and
|
||||
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
|
||||
rtol = 1e-2 if is_hip_cdna2() else 0
|
||||
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
print("❌ Triton and Torch differ")
|
||||
|
||||
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
|
||||
if TORCH_HAS_FP8 and is_cuda():
|
||||
torch.manual_seed(0)
|
||||
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
|
||||
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
|
||||
a = a.to(torch.float8_e5m2)
|
||||
# pre-transpose b for efficiency.
|
||||
b = b.T
|
||||
b = b.to(torch.float8_e5m2)
|
||||
triton_output = matmul(a, b)
|
||||
torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
|
||||
print(f"triton_output_with_fp8_inputs={triton_output}")
|
||||
print(f"torch_output_with_fp8_inputs={torch_output}")
|
||||
if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
|
||||
print("✅ Triton and Torch match")
|
||||
else:
|
||||
print("❌ Triton and Torch differ")
|
||||
|
||||
ref_lib = "cuBLAS" if is_cuda() else "rocBLAS"
|
||||
|
||||
configs = []
|
||||
for fp8_inputs in [False, True]:
|
||||
if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
|
||||
continue
|
||||
configs.append(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot
|
||||
x_vals=[
|
||||
128 * i for i in range(2, 33)
|
||||
], # Different possible values for `x_name`
|
||||
line_arg="provider", # Argument name whose value corresponds to a different line in the plot
|
||||
# Possible values for `line_arg`
|
||||
# Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
|
||||
line_vals=(
|
||||
["triton"] if fp8_inputs else [ref_lib.lower(), "triton"]
|
||||
), # Label name for the lines
|
||||
line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"], # Line styles
|
||||
styles=[("green", "-"), ("blue", "-")],
|
||||
ylabel="TFLOPS", # Label name for the y-axis
|
||||
plot_name="matmul-performance-"
|
||||
+ (
|
||||
"fp16" if not fp8_inputs else "fp8"
|
||||
), # Name for the plot, used also as a file name for saving the plot.
|
||||
args={"fp8_inputs": fp8_inputs},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
def benchmark(M, N, K, provider, fp8_inputs):
|
||||
a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
|
||||
b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
|
||||
if TORCH_HAS_FP8 and fp8_inputs:
|
||||
a = a.to(torch.float8_e5m2)
|
||||
b = b.T
|
||||
b = b.to(torch.float8_e5m2)
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
if provider == ref_lib.lower():
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: torch.matmul(a, b), quantiles=quantiles
|
||||
)
|
||||
if provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: matmul(a, b), quantiles=quantiles
|
||||
)
|
||||
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
|
||||
return perf(ms), perf(max_ms), perf(min_ms)
|
||||
|
||||
|
||||
benchmark.run(show_plots=True, print_data=True)
|
||||
27
tests/test_profille.py
Normal file
27
tests/test_profille.py
Normal file
@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
# 定义模型和优化器
|
||||
model = nn.Linear(100, 10).cuda()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
# 启动 Profiler
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], # 监控 GPU 和 CPU
|
||||
record_shapes=True, # 记录张量形状
|
||||
profile_memory=True, # 分析内存使用
|
||||
with_stack=True, # 记录调用栈
|
||||
) as prof:
|
||||
for _ in range(10):
|
||||
x = torch.randn(64, 100).cuda()
|
||||
y = model(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 输出分析结果
|
||||
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
|
||||
prof.export_chrome_trace("./trace.json")
|
||||
84
tests/test_triton.py
Normal file
84
tests/test_triton.py
Normal file
@ -0,0 +1,84 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_kernel(
|
||||
output_ptr,
|
||||
input_ptr,
|
||||
input_row_stride,
|
||||
output_row_stride,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
# 获取当前程序的行索引
|
||||
row_idx = tl.program_id(0)
|
||||
|
||||
# 计算输入和输出行的起始指针
|
||||
row_start_ptr = input_ptr + row_idx * input_row_stride
|
||||
output_row_start_ptr = output_ptr + row_idx * output_row_stride
|
||||
|
||||
# 将输入数据加载到本地内存
|
||||
row_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
input_ptrs = row_start_ptr + row_offsets
|
||||
row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf"))
|
||||
|
||||
# 计算 Softmax
|
||||
row_minus_max = row - tl.max(row, axis=0) # 数值稳定性:减去最大值
|
||||
numerator = tl.exp(row_minus_max)
|
||||
denominator = tl.sum(numerator, axis=0)
|
||||
softmax_output = numerator / denominator
|
||||
|
||||
# 将结果写回输出
|
||||
output_ptrs = output_row_start_ptr + row_offsets
|
||||
tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols)
|
||||
|
||||
|
||||
def softmax(x):
|
||||
n_rows, n_cols = x.shape
|
||||
|
||||
# 分配输出张量
|
||||
output = torch.empty_like(x)
|
||||
|
||||
# 定义 GPU 内核的网格和块大小
|
||||
BLOCK_SIZE = triton.next_power_of_2(n_cols)
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 2048:
|
||||
num_warps = 8
|
||||
if BLOCK_SIZE >= 4096:
|
||||
num_warps = 16
|
||||
|
||||
# 启动 Triton 内核
|
||||
softmax_kernel[(n_rows,)](
|
||||
output,
|
||||
x,
|
||||
x.stride(0),
|
||||
output.stride(0),
|
||||
n_cols,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# 测试 Softmax
|
||||
if __name__ == "__main__":
|
||||
# 创建一个随机矩阵
|
||||
x = torch.randn(4, 16, device="cuda")
|
||||
|
||||
# 使用 Triton 计算 Softmax
|
||||
output_triton = softmax(x)
|
||||
|
||||
# 使用 PyTorch 计算 Softmax 作为参考
|
||||
output_torch = torch.softmax(x, dim=1)
|
||||
|
||||
# 检查结果是否一致
|
||||
print("Input:")
|
||||
print(x)
|
||||
print("Triton Softmax:")
|
||||
print(output_triton)
|
||||
print("PyTorch Softmax:")
|
||||
print(output_torch)
|
||||
print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")
|
||||
176
tests/test_triton_mma.py
Normal file
176
tests/test_triton_mma.py
Normal file
@ -0,0 +1,176 @@
|
||||
# coding=utf-8
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_kernel(a_ptr, b_ptr, c_ptr, numel, BLOCK_SIZE: tl.constexpr):
|
||||
xidx = tl.program_id(0)
|
||||
index = xidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = index < numel
|
||||
a = tl.load(a_ptr + index, mask=mask)
|
||||
b = tl.load(b_ptr + index, mask=mask)
|
||||
c = a + b
|
||||
tl.store(c_ptr + index, c, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_mat_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
stride_m,
|
||||
stride_n,
|
||||
m,
|
||||
n,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
):
|
||||
midx = tl.program_id(0)
|
||||
nidx = tl.program_id(1)
|
||||
m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
|
||||
index = m_offset[:, None] * stride_m + n_offset[None, :] * stride_n
|
||||
a = tl.load(a_ptr + index, mask=mask)
|
||||
b = tl.load(b_ptr + index, mask=mask)
|
||||
c = a + b
|
||||
tl.store(c_ptr + index, c, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def threed_mat_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
stride_1,
|
||||
stride_m,
|
||||
stride_n,
|
||||
num_token,
|
||||
m,
|
||||
n,
|
||||
TOKEN_BLOCK: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
midx = tl.program_id(1)
|
||||
nidx = tl.program_id(2)
|
||||
# tl.device_print("token idx:", token_idx)
|
||||
m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
|
||||
index = (
|
||||
token_idx * stride_1
|
||||
+ m_offset[:, None] * stride_m
|
||||
+ n_offset[None, :] * stride_n
|
||||
)
|
||||
a = tl.load(a_ptr + index, mask=mask)
|
||||
b = tl.load(b_ptr + index, mask=mask)
|
||||
c = a + b
|
||||
tl.store(c_ptr + index, c, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def mma_kernel(
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
stride_am,
|
||||
stride_an,
|
||||
stride_bm,
|
||||
stride_bk,
|
||||
stride_cm,
|
||||
stride_ck,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
):
|
||||
midx = tl.program_id(0)
|
||||
nidx = tl.program_id(1)
|
||||
a_m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
a_n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
a_mask = (a_m_offset[:, None] < m) & (a_n_offset[None, :] < n)
|
||||
a_index = a_m_offset[:, None] * stride_am + a_n_offset[None, :] * stride_an
|
||||
a = tl.load(a_ptr + a_index, mask=a_mask)
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
|
||||
def test_add_kernel():
|
||||
a = torch.randn(size=(1024,), device="cuda")
|
||||
b = torch.randn(size=(1024,), device="cuda")
|
||||
c = torch.empty_like(a)
|
||||
BLOCK_SIZE = 32
|
||||
grid = lambda meta: (triton.cdiv(a.numel(), meta["BLOCK_SIZE"]),)
|
||||
add_kernel[grid](a, b, c, a.numel(), BLOCK_SIZE)
|
||||
real_c = a + b
|
||||
assert torch.allclose(real_c, c), "not equal"
|
||||
print("all right")
|
||||
|
||||
|
||||
def test_add_mat_kernel():
|
||||
a = torch.randn(size=(127, 255), device="cuda")
|
||||
b = torch.randn(size=(127, 255), device="cuda")
|
||||
c = torch.empty_like(a)
|
||||
BLOCK_SIZE_M = 32
|
||||
BLOCK_SIZE_N = 16
|
||||
grid = lambda meta: (
|
||||
triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
|
||||
triton.cdiv(a.size(1), meta["BLOCK_SIZE_N"]),
|
||||
)
|
||||
add_mat_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
a.size(0),
|
||||
a.size(1),
|
||||
BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N,
|
||||
)
|
||||
real_c = a + b
|
||||
assert torch.allclose(c, real_c), "not equal"
|
||||
print("all right")
|
||||
|
||||
|
||||
def test_three_dimension():
|
||||
num_token = 128
|
||||
a = torch.randn(size=(num_token, 127, 255), device="cuda")
|
||||
b = torch.randn(size=(num_token, 127, 255), device="cuda")
|
||||
c = torch.empty_like(a)
|
||||
BLOCK_SIZE_M = 32
|
||||
BLOCK_SIZE_N = 16
|
||||
TOKEN_BLOCK = a.size(0)
|
||||
grid = lambda meta: (
|
||||
a.size(0),
|
||||
triton.cdiv(a.size(1), meta["BLOCK_SIZE_M"]),
|
||||
triton.cdiv(a.size(2), meta["BLOCK_SIZE_N"]),
|
||||
)
|
||||
threed_mat_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
c,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
a.stride(2),
|
||||
a.size(0),
|
||||
a.size(1),
|
||||
a.size(2),
|
||||
TOKEN_BLOCK,
|
||||
BLOCK_SIZE_M,
|
||||
BLOCK_SIZE_N,
|
||||
)
|
||||
real_c = a + b
|
||||
assert torch.allclose(c, real_c), "not equal"
|
||||
print("all right")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_add_kernel()
|
||||
test_add_mat_kernel()
|
||||
test_three_dimension()
|
||||
Loading…
Reference in New Issue
Block a user