# coding=utf-8
"""Triton elementwise-add kernels (1-D, 2-D, 3-D) plus host-side self tests.

Each ``test_*`` function launches the matching kernel on CUDA and checks the
result against eager PyTorch with ``torch.allclose``.

NOTE(review): ``mma_kernel`` is an unfinished matrix-multiply sketch — it
loads one A tile and zero-initialises an accumulator, but has no K loop,
no multiply, and no store.  It is kept as-is (completing it would need a
``BLOCK_SIZE_K`` constexpr, i.e. a signature change); see TODO below.
"""
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(a_ptr, b_ptr, c_ptr, numel, BLOCK_SIZE: tl.constexpr):
    """Elementwise ``c = a + b`` over a flat buffer of ``numel`` values.

    One program instance handles one BLOCK_SIZE-wide slice of the buffer.
    """
    xidx = tl.program_id(0)
    index = xidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask guards the ragged final block when numel % BLOCK_SIZE != 0.
    mask = index < numel
    a = tl.load(a_ptr + index, mask=mask)
    b = tl.load(b_ptr + index, mask=mask)
    c = a + b
    tl.store(c_ptr + index, c, mask=mask)


@triton.jit
def add_mat_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_m,
    stride_n,
    m,
    n,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Elementwise ``c = a + b`` for an (m, n) matrix, tiled on a 2-D grid.

    ``stride_m``/``stride_n`` are the element strides of the (identically
    laid out) input and output tensors.
    """
    midx = tl.program_id(0)
    nidx = tl.program_id(1)
    m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Broadcast row mask against column mask -> (BLOCK_SIZE_M, BLOCK_SIZE_N)
    # tile mask for edge tiles.
    mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
    index = m_offset[:, None] * stride_m + n_offset[None, :] * stride_n
    a = tl.load(a_ptr + index, mask=mask)
    b = tl.load(b_ptr + index, mask=mask)
    c = a + b
    tl.store(c_ptr + index, c, mask=mask)


@triton.jit
def threed_mat_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    stride_1,
    stride_m,
    stride_n,
    num_token,
    m,
    n,
    TOKEN_BLOCK: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """Elementwise ``c = a + b`` for a (num_token, m, n) tensor.

    Grid axis 0 indexes the leading ("token") dimension one slice per
    program, axes 1/2 tile the trailing (m, n) plane.

    NOTE(review): ``num_token`` and ``TOKEN_BLOCK`` are accepted but unused —
    the launcher sizes grid axis 0 to the full token count, so no in-kernel
    bound check on that axis is performed.  Kept for interface compatibility.
    """
    token_idx = tl.program_id(0)
    midx = tl.program_id(1)
    nidx = tl.program_id(2)
    m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    mask = (m_offset[:, None] < m) & (n_offset[None, :] < n)
    index = (
        token_idx * stride_1
        + m_offset[:, None] * stride_m
        + n_offset[None, :] * stride_n
    )
    a = tl.load(a_ptr + index, mask=mask)
    b = tl.load(b_ptr + index, mask=mask)
    c = a + b
    tl.store(c_ptr + index, c, mask=mask)


@triton.jit
def mma_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    m,
    n,
    k,
    stride_am,
    stride_an,
    stride_bm,
    stride_bk,
    stride_cm,
    stride_ck,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    """INCOMPLETE matrix-multiply kernel (work in progress).

    TODO(review): only the A-tile load and accumulator init exist; the K
    reduction loop, the ``tl.dot`` accumulate, and the store to ``c_ptr``
    are missing, so ``b_ptr``/``c_ptr`` and the b/c strides are currently
    unused.  Do not launch this kernel as-is.
    """
    midx = tl.program_id(0)
    nidx = tl.program_id(1)
    a_m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    a_n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    a_mask = (a_m_offset[:, None] < m) & (a_n_offset[None, :] < n)
    a_index = a_m_offset[:, None] * stride_am + a_n_offset[None, :] * stride_an
    a = tl.load(a_ptr + a_index, mask=a_mask)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)


def test_add_kernel():
    """Check the 1-D add kernel against eager ``a + b``."""
    a = torch.randn(size=(1024,), device="cuda")
    b = torch.randn(size=(1024,), device="cuda")
    c = torch.empty_like(a)
    BLOCK_SIZE = 32
    grid = lambda meta: (triton.cdiv(a.numel(), meta["BLOCK_SIZE"]),)
    add_kernel[grid](a, b, c, a.numel(), BLOCK_SIZE)
    real_c = a + b
    assert torch.allclose(real_c, c), "not equal"
    print("all right")


def test_add_mat_kernel():
    """Check the 2-D add kernel on a non-block-aligned (127, 255) matrix."""
    a = torch.randn(size=(127, 255), device="cuda")
    b = torch.randn(size=(127, 255), device="cuda")
    c = torch.empty_like(a)
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 16
    # Deliberately ragged sizes exercise the edge-tile masking.
    grid = lambda meta: (
        triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]),
        triton.cdiv(a.size(1), meta["BLOCK_SIZE_N"]),
    )
    add_mat_kernel[grid](
        a,
        b,
        c,
        a.stride(0),
        a.stride(1),
        a.size(0),
        a.size(1),
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
    )
    real_c = a + b
    assert torch.allclose(c, real_c), "not equal"
    print("all right")


def test_three_dimension():
    """Check the 3-D add kernel over a (128, 127, 255) tensor."""
    num_token = 128
    a = torch.randn(size=(num_token, 127, 255), device="cuda")
    b = torch.randn(size=(num_token, 127, 255), device="cuda")
    c = torch.empty_like(a)
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 16
    TOKEN_BLOCK = a.size(0)
    # One program per token slice on axis 0; tile the (m, n) plane on 1/2.
    grid = lambda meta: (
        a.size(0),
        triton.cdiv(a.size(1), meta["BLOCK_SIZE_M"]),
        triton.cdiv(a.size(2), meta["BLOCK_SIZE_N"]),
    )
    threed_mat_kernel[grid](
        a,
        b,
        c,
        a.stride(0),
        a.stride(1),
        a.stride(2),
        a.size(0),
        a.size(1),
        a.size(2),
        TOKEN_BLOCK,
        BLOCK_SIZE_M,
        BLOCK_SIZE_N,
    )
    real_c = a + b
    assert torch.allclose(c, real_c), "not equal"
    print("all right")


if __name__ == "__main__":
    test_add_kernel()
    test_add_mat_kernel()
    test_three_dimension()