# torch_ext/tests/test_triton_mma.py
# Snapshot metadata: 2025-03-27 03:44:28 +08:00 — 177 lines, 4.4 KiB, Python
# coding=utf-8
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Masked 1-D elementwise add: c[i] = a[i] + b[i] for i < numel.

    One program handles one contiguous BLOCK_SIZE-wide slice; the mask
    keeps the final, partially-filled block from reading/writing past
    the end of the buffers.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < numel
    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
@triton.jit
def add_mat_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_m,
    stride_n,
    m,
    n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Tiled 2-D elementwise add: c[i, j] = a[i, j] + b[i, j].

    The launch grid is 2-D: axis 0 picks the row tile, axis 1 the column
    tile. All three tensors are addressed with the same (stride_m,
    stride_n) strides, so a, b and c are assumed to share one layout.
    The mask guards tiles that hang over the m x n boundary.
    """
    rows = tl.program_id(0) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tl.program_id(1) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    in_bounds = (rows[:, None] < m) & (cols[None, :] < n)
    offsets = rows[:, None] * stride_m + cols[None, :] * stride_n
    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
@triton.jit
def threed_mat_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_1,
    stride_m,
    stride_n,
    num_token,
    m,
    n,
    TOKEN_BLOCK: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Tiled 3-D elementwise add: c[t, i, j] = a[t, i, j] + b[t, i, j].

    The 3-D launch grid assigns one program per (token, row-tile,
    column-tile); the leading axis is not masked, so the grid's first
    dimension must equal the token count exactly.

    NOTE: `num_token` and `TOKEN_BLOCK` are accepted but unused by the
    index math below — they are kept only for launch-site compatibility.
    """
    token = tl.program_id(0)
    rows = tl.program_id(1) * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = tl.program_id(2) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    in_bounds = (rows[:, None] < m) & (cols[None, :] < n)
    offsets = (
        token * stride_1
        + rows[:, None] * stride_m
        + cols[None, :] * stride_n
    )
    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
@triton.jit
def mma_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bk,
    stride_cm,
    stride_ck,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Matrix-multiply-accumulate kernel — UNFINISHED STUB.

    NOTE(review): this kernel is incomplete and is never launched by the
    tests below. It loads a single BLOCK_SIZE_M x BLOCK_SIZE_N tile of A
    and zero-initializes an accumulator, but it never reads `b_ptr`,
    never performs a dot product, and never stores to `c_ptr`, so it has
    no observable effect. The intended tiling over the reduction
    dimension (there is no BLOCK_SIZE_K constexpr) cannot be inferred
    from what is here — confirm the design with the author before
    completing or deleting it.
    """
    # 2-D launch grid: axis 0 tiles the m dimension, axis 1 tiles n.
    midx = tl.program_id(0)
    nidx = tl.program_id(1)
    # Row/column indices of this program's A tile.
    a_m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Mask out-of-range rows/columns for tiles overhanging the m x n edge.
    a_mask = (a_m_offset[:, None] < m) & (a_n_offset[None, :] < n)
    a_index = a_m_offset[:, None] * stride_am + a_n_offset[None, :] * stride_an
    a = tl.load(a_ptr + a_index, mask=a_mask)
    # float32 accumulator for the (never-written) MMA result.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
def test_add_kernel():
    """Check the 1-D elementwise add kernel against eager PyTorch.

    Requires a CUDA device. Raises AssertionError if the kernel output
    diverges from `a + b`.
    """
    a = torch.randn(size=(1024,), device="cuda")
    b = torch.randn(size=(1024,), device="cuda")
    c = torch.empty_like(a)
    BLOCK_SIZE = 32

    def grid(meta):
        # Launch enough 1-D programs to cover every element.
        return (triton.cdiv(a.numel(), meta["BLOCK_SIZE"]),)

    # Pass the constexpr by keyword: the grid callable looks it up in
    # `meta`, and positionally-passed constexpr arguments are not exposed
    # there in some Triton versions (latent KeyError otherwise).
    add_kernel[grid](a, b, c, a.numel(), BLOCK_SIZE=BLOCK_SIZE)
    real_c = a + b
    assert torch.allclose(real_c, c), "not equal"
    print("all right")
def test_add_mat_kernel():
    """Check the tiled 2-D add kernel on a deliberately non-block-aligned
    127x255 tensor, so the boundary-mask path is exercised.

    Requires a CUDA device.
    """
    a = torch.randn(size=(127, 255), device="cuda")
    b = torch.randn(size=(127, 255), device="cuda")
    c = torch.empty_like(a)
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 16

    def grid(meta):
        # One program per (row-tile, column-tile), rounding up at the edges.
        return (
            triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
            triton.cdiv(a.size(1), meta["BLOCK_SIZE_N"]),
        )

    add_mat_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1),
        a.size(0), a.size(1),
        BLOCK_SIZE_M, BLOCK_SIZE_N,
    )
    real_c = a + b
    assert torch.allclose(c, real_c), "not equal"
    print("all right")
def test_three_dimension():
    """Check the tiled 3-D add kernel on a (128, 127, 255) tensor; the
    trailing dims are non-block-aligned so the mask path is exercised.

    Requires a CUDA device.
    """
    num_token = 128
    a = torch.randn(size=(num_token, 127, 255), device="cuda")
    b = torch.randn(size=(num_token, 127, 255), device="cuda")
    c = torch.empty_like(a)
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 16
    # Passed through to the kernel, which currently ignores it.
    TOKEN_BLOCK = a.size(0)

    def grid(meta):
        # One program per (token, row-tile, column-tile).
        return (
            a.size(0),
            triton.cdiv(a.size(1), meta["BLOCK_SIZE_M"]),
            triton.cdiv(a.size(2), meta["BLOCK_SIZE_N"]),
        )

    threed_mat_kernel[grid](
        a, b, c,
        a.stride(0), a.stride(1), a.stride(2),
        a.size(0), a.size(1), a.size(2),
        TOKEN_BLOCK, BLOCK_SIZE_M, BLOCK_SIZE_N,
    )
    real_c = a + b
    assert torch.allclose(c, real_c), "not equal"
    print("all right")
if __name__ == "__main__":
    # Run all three elementwise-add sanity checks; each needs a CUDA device
    # and raises AssertionError on mismatch. mma_kernel has no test here.
    test_add_kernel()
    test_add_mat_kernel()
    test_three_dimension()