test triton, seems to work very well.

longfei li 2025-03-27 03:44:28 +08:00
parent 58093d7a71
commit c77f9602ea
6 changed files with 792 additions and 0 deletions

csrc/quantize.cu (new file, +36 lines)

@@ -0,0 +1,36 @@
#include <cuda_fp16.h>
#include <cuda_fp8.h>

// Quantize an fp16 matrix to e5m2 fp8 by dividing every element by the
// block-wide maximum. Each thread covers a (blockDim.x x blockDim.y) tile
// starting at (threadIdx.x * blockDim.x, threadIdx.y * blockDim.y), so the
// kernel assumes a single block whose threads together span the whole matrix.
__global__ void quantize(const half *src, __nv_fp8_storage_t *dest, int x_len, int y_len)
{
    int x_start = threadIdx.x * blockDim.x;
    int y_start = threadIdx.y * blockDim.y;
    int tid = threadIdx.y * blockDim.x + threadIdx.x;
    // Per-thread maxima, reduced to a single shared value by thread 0
    // (a block has at most 1024 threads).
    __shared__ half thread_max[1024];
    __shared__ half max_value;
    half local_max = __float2half(-10000.0f);
    for (int i = 0; i < blockDim.x; i++)
    {
        for (int j = 0; j < blockDim.y; j++)
        {
            if (x_start + i < x_len && y_start + j < y_len)
            {
                int real_offset = (y_start + j) * x_len + x_start + i;
                local_max = __hmax(src[real_offset], local_max);
            }
        }
    }
    thread_max[tid] = local_max;
    __syncthreads();
    if (tid == 0)
    {
        max_value = thread_max[0];
        for (int t = 1; t < blockDim.x * blockDim.y; t++)
            max_value = __hmax(thread_max[t], max_value);
    }
    __syncthreads();
    // Scale each element by the block maximum and convert to fp8 (e5m2).
    for (int i = 0; i < blockDim.x; i++)
    {
        for (int j = 0; j < blockDim.y; j++)
        {
            if (x_start + i < x_len && y_start + j < y_len)
            {
                int real_offset = (y_start + j) * x_len + x_start + i;
                half tmp = __hdiv(src[real_offset], max_value);
                dest[real_offset] = __nv_cvt_halfraw_to_fp8(static_cast<__half_raw>(tmp), __NV_SATFINITE, __NV_E5M2);
            }
        }
    }
}
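For reference, a minimal host-side launch sketch (not part of this commit; the helper name and block shape are illustrative only), assuming a single 16x16-thread block so the threads jointly cover a matrix of at most 256x256 elements:

#include <cuda_fp16.h>
#include <cuda_fp8.h>

void quantize_fp16_to_fp8(const half *d_src, __nv_fp8_storage_t *d_dst, int x_len, int y_len)
{
    // Hypothetical helper: each of the 16x16 threads handles a 16x16 tile,
    // so one block spans up to 256x256 elements.
    dim3 block(16, 16);
    quantize<<<1, block>>>(d_src, d_dst, x_len, y_len);
    cudaDeviceSynchronize();
}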

@@ -11,6 +11,7 @@ files = [
"csrc/core_bind.cpp", "csrc/core_bind.cpp",
"csrc/max.cu", "csrc/max.cu",
"csrc/md.cu", "csrc/md.cu",
"csrc/quantize.cu",
] ]
extension = CUDAExtension( extension = CUDAExtension(
name="torch_cuda_ext.core", name="torch_cuda_ext.core",

tests/test_mma.py (new file, +468 lines)

@@ -0,0 +1,468 @@
import torch
import triton
import triton.language as tl
DEVICE = "cuda"
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def is_hip_cdna2():
target = triton.runtime.driver.active.get_current_target()
return target.backend == "hip" and target.arch == "gfx90a"
def get_cuda_autotune_config():
return [
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
# Good config for fp8 inputs.
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=3,
num_warps=8,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
]
def get_hip_autotune_config():
return [
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 16,
"GROUP_SIZE_M": 1,
"waves_per_eu": 2,
},
num_warps=4,
num_stages=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 16,
"GROUP_SIZE_M": 4,
"waves_per_eu": 2,
},
num_warps=8,
num_stages=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"waves_per_eu": 2,
},
num_warps=8,
num_stages=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"waves_per_eu": 3,
},
num_warps=4,
num_stages=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"waves_per_eu": 8,
},
num_warps=4,
num_stages=2,
),
]
def get_autotune_config():
if is_cuda():
return get_cuda_autotune_config()
else:
return get_hip_autotune_config()
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
configs=get_autotune_config(),
key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Matrix dimensions
M,
N,
K,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am,
stride_ak, #
stride_bk,
stride_bn, #
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, #
GROUP_SIZE_M: tl.constexpr, #
ACTIVATION: tl.constexpr, #
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
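# Worked example: with num_pid_m = num_pid_n = 4 and GROUP_SIZE_M = 2,
# num_pid_in_group = 8, so pids 0..7 map to (pid_m, pid_n) =
# (0,0),(1,0),(0,1),(1,1),(0,2),(1,2),(0,3),(1,3) and pids 8..15 cover
# pid_m in {2, 3}. Each group keeps reusing the same GROUP_SIZE_M row-blocks
# of A while it sweeps across the column-blocks of B, which is what improves
# L2 reuse.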
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetic` section for details
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
c = accumulator.to(tl.float16)
# -----------------------------------------------------------
# Write back the block of the output matrix C with masks.
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`.
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)
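# For example, matmul(a, b, activation="leaky_relu") applies the activation to the
# fp32 accumulator inside the kernel, before the result is cast to fp16 and stored.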
def matmul(a, b, activation=""):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
matmul_kernel[grid](
a,
b,
c, #
M,
N,
K, #
a.stride(0),
a.stride(1), #
b.stride(0),
b.stride(1), #
c.stride(0),
c.stride(1), #
ACTIVATION=activation, #
)
return c
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
# Bigger tolerance for AMD CDNA2 devices.
# CDNA2 devices use reduced precision fp16 and bf16 and flush input and
# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
rtol = 1e-2 if is_hip_cdna2() else 0
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")
if TORCH_HAS_FP8 and is_cuda():
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16)
a = a.to(torch.float8_e5m2)
# pre-transpose b for efficiency.
b = b.T
b = b.to(torch.float8_e5m2)
triton_output = matmul(a, b)
torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16))
print(f"triton_output_with_fp8_inputs={triton_output}")
print(f"torch_output_with_fp8_inputs={torch_output}")
if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")
ref_lib = "cuBLAS" if is_cuda() else "rocBLAS"
configs = []
for fp8_inputs in [False, True]:
if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
continue
configs.append(
triton.testing.Benchmark(
x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot
x_vals=[
128 * i for i in range(2, 33)
], # Different possible values for `x_name`
line_arg="provider", # Argument name whose value corresponds to a different line in the plot
# Possible values for `line_arg`
# Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment.
line_vals=(
["triton"] if fp8_inputs else [ref_lib.lower(), "triton"]
),
line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"], # Label name for the lines
styles=[("green", "-"), ("blue", "-")], # Line styles
ylabel="TFLOPS", # Label name for the y-axis
plot_name="matmul-performance-"
+ (
"fp16" if not fp8_inputs else "fp8"
), # Name for the plot, used also as a file name for saving the plot.
args={"fp8_inputs": fp8_inputs},
)
)
@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
a = torch.randn((M, K), device=DEVICE, dtype=torch.float16)
b = torch.randn((K, N), device=DEVICE, dtype=torch.float16)
if TORCH_HAS_FP8 and fp8_inputs:
a = a.to(torch.float8_e5m2)
b = b.T
b = b.to(torch.float8_e5m2)
quantiles = [0.5, 0.2, 0.8]
if provider == ref_lib.lower():
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: torch.matmul(a, b), quantiles=quantiles
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: matmul(a, b), quantiles=quantiles
)
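# A GEMM performs 2 * M * N * K floating-point operations (one multiply and one
# add per output element for each K step), so dividing by the runtime in seconds
# gives TFLOP/s.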
perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
return perf(ms), perf(max_ms), perf(min_ms)
benchmark.run(show_plots=True, print_data=True)

tests/test_profille.py (new file, +27 lines)

@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.profiler import profile, record_function, ProfilerActivity
# Define the model and the optimizer
model = nn.Linear(100, 10).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Start the profiler
with profile(
activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], # profile both GPU and CPU
record_shapes=True, # record tensor shapes
profile_memory=True, # track memory usage
with_stack=True, # record call stacks
) as prof:
for _ in range(10):
x = torch.randn(64, 100).cuda()
y = model(x)
loss = y.sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
# Print the profiling results
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
prof.export_chrome_trace("./trace.json")
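# The exported trace.json is in Chrome trace format; it can be opened in
# chrome://tracing or https://ui.perfetto.dev to inspect the CPU/GPU timeline.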

tests/test_triton.py (new file, +84 lines)

@@ -0,0 +1,84 @@
import torch
import triton
import triton.language as tl
@triton.jit
def softmax_kernel(
output_ptr,
input_ptr,
input_row_stride,
output_row_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
# Row index handled by this program instance
row_idx = tl.program_id(0)
# Compute the start pointers of the input and output rows
row_start_ptr = input_ptr + row_idx * input_row_stride
output_row_start_ptr = output_ptr + row_idx * output_row_stride
# Load the input row into on-chip memory
row_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + row_offsets
row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf"))
# Compute the softmax
row_minus_max = row - tl.max(row, axis=0) # numerical stability: subtract the row max
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Write the result back to the output row
output_ptrs = output_row_start_ptr + row_offsets
tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols)
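# Each program instance handles one complete row, so BLOCK_SIZE must be at least
# n_cols; softmax() below rounds n_cols up to the next power of two, which also
# satisfies tl.arange's requirement of a power-of-two length.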
def softmax(x):
n_rows, n_cols = x.shape
# Allocate the output tensor
output = torch.empty_like(x)
# Choose the block size and warp count for the kernel
BLOCK_SIZE = triton.next_power_of_2(n_cols)
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Launch the Triton kernel
softmax_kernel[(n_rows,)](
output,
x,
x.stride(0),
output.stride(0),
n_cols,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return output
# Test the softmax
if __name__ == "__main__":
# Create a random matrix
x = torch.randn(4, 16, device="cuda")
# Compute softmax with Triton
output_triton = softmax(x)
# Compute softmax with PyTorch as a reference
output_torch = torch.softmax(x, dim=1)
# Check that the results match
print("Input:")
print(x)
print("Triton Softmax:")
print(output_triton)
print("PyTorch Softmax:")
print(output_torch)
print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}")

tests/test_triton_mma.py (new file, +176 lines)

@@ -0,0 +1,176 @@
# coding=utf-8
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr, numel, BLOCK_SIZE: tl.constexpr):
xidx = tl.program_id(0)
index = xidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = index < numel
a = tl.load(a_ptr + index, mask=mask)
b = tl.load(b_ptr + index, mask=mask)
c = a + b
tl.store(c_ptr + index, c, mask=mask)
@triton.jit
def add_mat_kernel(
a_ptr,
b_ptr,
c_ptr,
stride_m,
stride_n,
m,
n,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
midx = tl.program_id(0)
nidx = tl.program_id(1)
m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
index = m_offset[:, None] * stride_m + n_offset[None, :] * stride_n
a = tl.load(a_ptr + index, mask=mask)
b = tl.load(b_ptr + index, mask=mask)
c = a + b
tl.store(c_ptr + index, c, mask=mask)
@triton.jit
def threed_mat_kernel(
a_ptr,
b_ptr,
c_ptr,
stride_1,
stride_m,
stride_n,
num_token,
m,
n,
TOKEN_BLOCK: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
token_idx = tl.program_id(0)
midx = tl.program_id(1)
nidx = tl.program_id(2)
# tl.device_print("token idx:", token_idx)
m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
index = (
token_idx * stride_1
+ m_offset[:, None] * stride_m
+ n_offset[None, :] * stride_n
)
a = tl.load(a_ptr + index, mask=mask)
b = tl.load(b_ptr + index, mask=mask)
c = a + b
tl.store(c_ptr + index, c, mask=mask)
@triton.jit
def mma_kernel(
a_ptr,
b_ptr,
c_ptr,
m,
n,
k,
stride_am,
stride_an,
stride_bm,
stride_bk,
stride_cm,
stride_ck,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
):
midx = tl.program_id(0)
nidx = tl.program_id(1)
a_m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
a_n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
a_mask = (a_m_offset[:, None] < m) & (a_n_offset[None, :] < n)
a_index = a_m_offset[:, None] * stride_am + a_n_offset[None, :] * stride_an
a = tl.load(a_ptr + a_index, mask=a_mask)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
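# (work in progress: so far only the A tile is loaded and the accumulator is
# zero-initialized; the loop over K, the tl.dot accumulation, and the store of
# the C tile are still missing)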
def test_add_kernel():
a = torch.randn(size=(1024,), device="cuda")
b = torch.randn(size=(1024,), device="cuda")
c = torch.empty_like(a)
BLOCK_SIZE = 32
grid = lambda meta: (triton.cdiv(a.numel(), meta["BLOCK_SIZE"]),)
add_kernel[grid](a, b, c, a.numel(), BLOCK_SIZE)
real_c = a + b
assert torch.allclose(real_c, c), "not equal"
print("all right")
def test_add_mat_kernel():
a = torch.randn(size=(127, 255), device="cuda")
b = torch.randn(size=(127, 255), device="cuda")
c = torch.empty_like(a)
BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 16
grid = lambda meta: (
triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
triton.cdiv(a.size(1), meta["BLOCK_SIZE_N"]),
)
add_mat_kernel[grid](
a,
b,
c,
a.stride(0),
a.stride(1),
a.size(0),
a.size(1),
BLOCK_SIZE_M,
BLOCK_SIZE_N,
)
real_c = a + b
assert torch.allclose(c, real_c), "not equal"
print("all right")
def test_three_dimension():
num_token = 128
a = torch.randn(size=(num_token, 127, 255), device="cuda")
b = torch.randn(size=(num_token, 127, 255), device="cuda")
c = torch.empty_like(a)
BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 16
TOKEN_BLOCK = a.size(0)
grid = lambda meta: (
a.size(0),
triton.cdiv(a.size(1), meta["BLOCK_SIZE_M"]),
triton.cdiv(a.size(2), meta["BLOCK_SIZE_N"]),
)
threed_mat_kernel[grid](
a,
b,
c,
a.stride(0),
a.stride(1),
a.stride(2),
a.size(0),
a.size(1),
a.size(2),
TOKEN_BLOCK,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
)
real_c = a + b
assert torch.allclose(c, real_c), "not equal"
print("all right")
if __name__ == "__main__":
test_add_kernel()
test_add_mat_kernel()
test_three_dimension()