# Listing metadata: 177 lines, 4.4 KiB, Python
# coding=utf-8
|
|
import torch
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Write ``a + b`` element by element into ``c``.

    Every program instance covers ``BLOCK_SIZE`` consecutive offsets, chosen
    by its index along launch axis 0.

    Args:
        a_ptr: Base pointer of the first operand.
        b_ptr: Base pointer of the second operand.
        c_ptr: Base pointer of the output.
        numel: Number of valid elements; offsets at or past it are masked.
        BLOCK_SIZE: Compile-time count of offsets handled per program.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of the data, so every memory
    # access is guarded by the same bounds mask.
    in_bounds = offsets < numel
    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
|
|
@triton.jit
def add_mat_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_m,
    stride_n,
    m,
    n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Add two 2-D operands tile by tile and store the result.

    Launch axis 0 selects the tile along the ``m`` dimension and launch
    axis 1 selects the tile along the ``n`` dimension.

    Args:
        a_ptr: Base pointer of the first operand.
        b_ptr: Base pointer of the second operand.
        c_ptr: Base pointer of the output.
        stride_m: Element step between consecutive ``m`` indices.
        stride_n: Element step between consecutive ``n`` indices.
        m: Extent of the first dimension.
        n: Extent of the second dimension.
        BLOCK_SIZE_M: Compile-time tile height.
        BLOCK_SIZE_N: Compile-time tile width.

    Note:
        The single ``stride_m``/``stride_n`` pair is applied to all three
        pointers, so ``a``, ``b`` and ``c`` must share one memory layout.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Expand to a column vector and a row vector so they broadcast into a
    # (BLOCK_SIZE_M, BLOCK_SIZE_N) tile.
    row_ids = rows[:, None]
    col_ids = cols[None, :]

    in_bounds = (row_ids < m) & (col_ids < n)
    offsets = row_ids * stride_m + col_ids * stride_n

    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
|
|
@triton.jit
def threed_mat_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_1,
    stride_m,
    stride_n,
    num_token,
    m,
    n,
    TOKEN_BLOCK: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Add two 3-D operands, one 2-D tile of one leading slice per program.

    Launch axis 0 picks the leading index, axis 1 the tile along ``m`` and
    axis 2 the tile along ``n``.

    Args:
        a_ptr: Base pointer of the first operand.
        b_ptr: Base pointer of the second operand.
        c_ptr: Base pointer of the output.
        stride_1: Element step between consecutive leading indices.
        stride_m: Element step between consecutive ``m`` indices.
        stride_n: Element step between consecutive ``n`` indices.
        num_token: Extent of the leading dimension (not read by the body).
        m: Extent of the second dimension.
        n: Extent of the third dimension.
        TOKEN_BLOCK: Compile-time value (not read by the body).
        BLOCK_SIZE_M: Compile-time tile height.
        BLOCK_SIZE_N: Compile-time tile width.

    Note:
        One stride triple addresses all three pointers, so the operands and
        the output must share a layout. The leading index is not
        bounds-checked; the launch grid must not exceed the leading extent.
    """
    token_idx = tl.program_id(0)
    pid_m = tl.program_id(1)
    pid_n = tl.program_id(2)
    # For debugging: tl.device_print("token idx:", token_idx)

    rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    row_ids = rows[:, None]
    col_ids = cols[None, :]

    in_bounds = (row_ids < m) & (col_ids < n)

    # Offset of the selected leading slice, then the position inside it.
    slice_base = token_idx * stride_1
    offsets = slice_base + row_ids * stride_m + col_ids * stride_n

    lhs = tl.load(a_ptr + offsets, mask=in_bounds)
    rhs = tl.load(b_ptr + offsets, mask=in_bounds)
    tl.store(c_ptr + offsets, lhs + rhs, mask=in_bounds)
|
|
|
|
|
|
@triton.jit
def mma_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bk,
    stride_cm,
    stride_ck,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Load one tile of ``a`` and create a zeroed float32 accumulator.

    NOTE(review): this kernel is unfinished. It never reads ``b_ptr``,
    never writes ``c_ptr``, and does not use ``k`` or the ``b``/``c``
    strides, so launching it produces no output. The name and signature
    suggest a matrix multiply was intended -- confirm before completing.

    Args:
        a_ptr: Base pointer of the ``a`` operand.
        b_ptr: Base pointer of the ``b`` operand (unused).
        c_ptr: Base pointer of the output (unused).
        m: Extent of ``a`` along its first dimension.
        n: Extent of ``a`` along its second dimension.
        k: Unused.
        stride_am: Element step of ``a`` along its first dimension.
        stride_an: Element step of ``a`` along its second dimension.
        stride_bm: Unused.
        stride_bk: Unused.
        stride_cm: Unused.
        stride_ck: Unused.
        BLOCK_SIZE_M: Compile-time tile height.
        BLOCK_SIZE_N: Compile-time tile width.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    a_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    a_row_ids = a_rows[:, None]
    a_col_ids = a_cols[None, :]

    a_in_bounds = (a_row_ids < m) & (a_col_ids < n)
    a_offsets = a_row_ids * stride_am + a_col_ids * stride_an

    # Both values below are currently dead: nothing consumes them yet.
    a = tl.load(a_ptr + a_offsets, mask=a_in_bounds)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
|
|
|
|
|
def test_add_kernel():
    """Check ``add_kernel`` against ``a + b`` on 1024 random CUDA values."""
    lhs = torch.randn(size=(1024,), device="cuda")
    rhs = torch.randn(size=(1024,), device="cuda")
    out = torch.empty_like(lhs)
    block_size = 32

    def grid(meta):
        # One program per BLOCK_SIZE-sized run, rounded up.
        return (triton.cdiv(lhs.numel(), meta["BLOCK_SIZE"]),)

    add_kernel[grid](lhs, rhs, out, lhs.numel(), block_size)

    expected = lhs + rhs
    assert torch.allclose(expected, out), "not equal"
    print("all right")
|
|
|
|
|
|
def test_add_mat_kernel():
    """Check ``add_mat_kernel`` against ``a + b`` on a 127 x 255 CUDA matrix.

    Neither extent is a multiple of its tile size, so the edge tiles
    exercise the kernel's bounds mask.
    """
    lhs = torch.randn(size=(127, 255), device="cuda")
    rhs = torch.randn(size=(127, 255), device="cuda")
    out = torch.empty_like(lhs)
    block_m = 32
    block_n = 16

    def grid(meta):
        # Tiles along each dimension, rounded up.
        tiles_m = triton.cdiv(lhs.size(0), meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(lhs.size(1), meta["BLOCK_SIZE_N"])
        return (tiles_m, tiles_n)

    rows, cols = lhs.size(0), lhs.size(1)
    add_mat_kernel[grid](
        lhs,
        rhs,
        out,
        lhs.stride(0),
        lhs.stride(1),
        rows,
        cols,
        block_m,
        block_n,
    )

    expected = lhs + rhs
    assert torch.allclose(out, expected), "not equal"
    print("all right")
|
|
|
|
|
|
def test_three_dimension():
    """Check ``threed_mat_kernel`` against ``a + b`` on a 128 x 127 x 255 CUDA tensor."""
    num_token = 128
    shape = (num_token, 127, 255)
    lhs = torch.randn(size=shape, device="cuda")
    rhs = torch.randn(size=shape, device="cuda")
    out = torch.empty_like(lhs)
    block_m = 32
    block_n = 16
    token_block = lhs.size(0)

    def grid(meta):
        # One program per leading index and per (m, n) tile, rounded up.
        tiles_m = triton.cdiv(lhs.size(1), meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(lhs.size(2), meta["BLOCK_SIZE_N"])
        return (lhs.size(0), tiles_m, tiles_n)

    threed_mat_kernel[grid](
        lhs,
        rhs,
        out,
        lhs.stride(0),
        lhs.stride(1),
        lhs.stride(2),
        lhs.size(0),
        lhs.size(1),
        lhs.size(2),
        token_block,
        block_m,
        block_n,
    )

    expected = lhs + rhs
    assert torch.allclose(out, expected), "not equal"
    print("all right")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every self-check, in definition order, when executed as a script.
    for check in (test_add_kernel, test_add_mat_kernel, test_three_dimension):
        check()
|