diff --git a/csrc/quantize.cu b/csrc/quantize.cu
new file mode 100644
index 0000000..92fe22a
--- /dev/null
+++ b/csrc/quantize.cu
@@ -0,0 +1,38 @@
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+// Normalize each tile by its maximum element and convert the result to FP8 (E5M2).
+__global__ void quantize(const half *src, __nv_fp8_storage_t *dest, int x_len, int y_len)
+{
+    int x_start = blockIdx.x * blockDim.x;
+    int y_start = blockIdx.y * blockDim.y;
+    __shared__ half max_value;
+
+    // One thread scans the whole tile for its maximum value.
+    if (threadIdx.x == 0 && threadIdx.y == 0)
+    {
+        max_value = __float2half(-10000.0f);
+        for (int i = 0; i < blockDim.x; i++)
+        {
+            for (int j = 0; j < blockDim.y; j++)
+            {
+                if (x_start + i < x_len && y_start + j < y_len)
+                {
+                    int real_offset = (y_start + j) * x_len + x_start + i;
+                    max_value = __hmax(src[real_offset], max_value);
+                }
+            }
+        }
+    }
+    __syncthreads();
+
+    // Each thread then quantizes its own element of the tile.
+    int x = x_start + threadIdx.x;
+    int y = y_start + threadIdx.y;
+    if (x < x_len && y < y_len)
+    {
+        int real_offset = y * x_len + x;
+        half tmp = __hdiv(src[real_offset], max_value);
+        dest[real_offset] = __nv_cvt_halfraw_to_fp8(__half_raw(tmp), __NV_SATFINITE, __NV_E5M2);
+    }
+}
diff --git a/setup.py b/setup.py
index 3b9bf4d..51a843b 100644
--- a/setup.py
+++ b/setup.py
@@ -11,6 +11,7 @@ files = [
     "csrc/core_bind.cpp",
     "csrc/max.cu",
     "csrc/md.cu",
+    "csrc/quantize.cu",
 ]
 extension = CUDAExtension(
     name="torch_cuda_ext.core",
diff --git a/tests/test_mma.py b/tests/test_mma.py
new file mode 100644
index 0000000..da5f03e
--- /dev/null
+++ b/tests/test_mma.py
@@ -0,0 +1,472 @@
+import torch
+
+import triton
+import triton.language as tl
+
+DEVICE = "cuda"
+
+
+def is_cuda():
+    return triton.runtime.driver.active.get_current_target().backend == "cuda"
+
+
+def is_hip_cdna2():
+    target = triton.runtime.driver.active.get_current_target()
+    return target.backend == "hip" and target.arch == "gfx90a"
+
+
+def get_cuda_autotune_config():
+    return [
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=3,
+            num_warps=8,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 256,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 128,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=4,
+            num_warps=4,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=5,
+            num_warps=2,
+        ),
+        triton.Config(
+            {
+                "BLOCK_SIZE_M": 32,
+                "BLOCK_SIZE_N": 64,
+                "BLOCK_SIZE_K": 32,
+                "GROUP_SIZE_M": 8,
+            },
+            num_stages=5,
+            num_warps=2,
+        ),
+        # Good config for fp8 inputs.
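+        # fp8 elements are half the size of fp16, so the fp8 configs below use a
+        # larger BLOCK_SIZE_K: a deeper K tile fits in the same shared-memory budget.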
+ triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=3, + num_warps=8, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + ] + + +def get_hip_autotune_config(): + return [ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + }, + num_warps=4, + num_stages=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 16, + "GROUP_SIZE_M": 4, + "waves_per_eu": 2, + }, + num_warps=8, + num_stages=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "waves_per_eu": 2, + }, + num_warps=8, + num_stages=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "waves_per_eu": 3, + }, + num_warps=4, + num_stages=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "waves_per_eu": 8, + }, + num_warps=4, + num_stages=2, + ), + ] + + +def get_autotune_config(): + if is_cuda(): + return get_cuda_autotune_config() + else: + return get_hip_autotune_config() + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=get_autotune_config(), + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. 
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + # You can fuse arbitrary activation functions here + # while the accumulator is still in FP32! + if ACTIVATION == "leaky_relu": + accumulator = leaky_relu(accumulator) + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# We can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `matmul_kernel`. +@triton.jit +def leaky_relu(x): + return tl.where(x >= 0, x, 0.01 * x) + + +def matmul(a, b, activation=""): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. 
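+    # `grid` is a callable so the autotuner can re-evaluate it with the
+    # BLOCK_SIZE_M / BLOCK_SIZE_N of whichever config it is currently trying.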
+ grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + matmul_kernel[grid]( + a, + b, + c, # + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + ACTIVATION=activation, # + ) + return c + + +torch.manual_seed(0) +a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) +b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) +triton_output = matmul(a, b) +torch_output = torch.matmul(a, b) +print(f"triton_output_with_fp16_inputs={triton_output}") +print(f"torch_output_with_fp16_inputs={torch_output}") +# Bigger tolerance for AMD CDNA2 devices. +# CDNA2 devices use reduced precision fp16 and bf16 and flush input and +# output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices +rtol = 1e-2 if is_hip_cdna2() else 0 +if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print("✅ Triton and Torch match") +else: + print("❌ Triton and Torch differ") + +TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2") +if TORCH_HAS_FP8 and is_cuda(): + torch.manual_seed(0) + a = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) + b = torch.randn((512, 512), device=DEVICE, dtype=torch.float16) + a = a.to(torch.float8_e5m2) + # pre-transpose b for efficiency. + b = b.T + b = b.to(torch.float8_e5m2) + triton_output = matmul(a, b) + torch_output = torch.matmul(a.to(torch.float16), b.to(torch.float16)) + print(f"triton_output_with_fp8_inputs={triton_output}") + print(f"torch_output_with_fp8_inputs={torch_output}") + if torch.allclose(triton_output, torch_output, atol=0.125, rtol=0): + print("✅ Triton and Torch match") + else: + print("❌ Triton and Torch differ") + +ref_lib = "cuBLAS" if is_cuda() else "rocBLAS" + +configs = [] +for fp8_inputs in [False, True]: + if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()): + continue + configs.append( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=[ + 128 * i for i in range(2, 33) + ], # Different possible values for `x_name` + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. + line_vals=( + ["triton"] if fp8_inputs else [ref_lib.lower(), "triton"] + ), # Label name for the lines + line_names=["Triton"] if fp8_inputs else [ref_lib, "Triton"], # Line styles + styles=[("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance-" + + ( + "fp16" if not fp8_inputs else "fp8" + ), # Name for the plot, used also as a file name for saving the plot. 
+ args={"fp8_inputs": fp8_inputs}, + ) + ) + + +@triton.testing.perf_report(configs) +def benchmark(M, N, K, provider, fp8_inputs): + a = torch.randn((M, K), device=DEVICE, dtype=torch.float16) + b = torch.randn((K, N), device=DEVICE, dtype=torch.float16) + if TORCH_HAS_FP8 and fp8_inputs: + a = a.to(torch.float8_e5m2) + b = b.T + b = b.to(torch.float8_e5m2) + quantiles = [0.5, 0.2, 0.8] + if provider == ref_lib.lower(): + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.matmul(a, b), quantiles=quantiles + ) + if provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul(a, b), quantiles=quantiles + ) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +benchmark.run(show_plots=True, print_data=True) diff --git a/tests/test_profille.py b/tests/test_profille.py new file mode 100644 index 0000000..70a0ece --- /dev/null +++ b/tests/test_profille.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.profiler import profile, record_function, ProfilerActivity + +# 定义模型和优化器 +model = nn.Linear(100, 10).cuda() +optimizer = optim.SGD(model.parameters(), lr=0.01) + +# 启动 Profiler +with profile( + activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU], # 监控 GPU 和 CPU + record_shapes=True, # 记录张量形状 + profile_memory=True, # 分析内存使用 + with_stack=True, # 记录调用栈 +) as prof: + for _ in range(10): + x = torch.randn(64, 100).cuda() + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + optimizer.zero_grad() + +# 输出分析结果 +print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) +prof.export_chrome_trace("./trace.json") diff --git a/tests/test_triton.py b/tests/test_triton.py new file mode 100644 index 0000000..7755d16 --- /dev/null +++ b/tests/test_triton.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def softmax_kernel( + output_ptr, + input_ptr, + input_row_stride, + output_row_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + # 获取当前程序的行索引 + row_idx = tl.program_id(0) + + # 计算输入和输出行的起始指针 + row_start_ptr = input_ptr + row_idx * input_row_stride + output_row_start_ptr = output_ptr + row_idx * output_row_stride + + # 将输入数据加载到本地内存 + row_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + row_offsets + row = tl.load(input_ptrs, mask=row_offsets < n_cols, other=-float("inf")) + + # 计算 Softmax + row_minus_max = row - tl.max(row, axis=0) # 数值稳定性:减去最大值 + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + + # 将结果写回输出 + output_ptrs = output_row_start_ptr + row_offsets + tl.store(output_ptrs, softmax_output, mask=row_offsets < n_cols) + + +def softmax(x): + n_rows, n_cols = x.shape + + # 分配输出张量 + output = torch.empty_like(x) + + # 定义 GPU 内核的网格和块大小 + BLOCK_SIZE = triton.next_power_of_2(n_cols) + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + + # 启动 Triton 内核 + softmax_kernel[(n_rows,)]( + output, + x, + x.stride(0), + output.stride(0), + n_cols, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + + return output + + +# 测试 Softmax +if __name__ == "__main__": + # 创建一个随机矩阵 + x = torch.randn(4, 16, device="cuda") + + # 使用 Triton 计算 Softmax + output_triton = softmax(x) + + # 使用 PyTorch 计算 Softmax 作为参考 + output_torch = torch.softmax(x, dim=1) + + # 检查结果是否一致 + print("Input:") + print(x) + print("Triton Softmax:") + print(output_triton) + print("PyTorch Softmax:") + 
print(output_torch) + print(f"Are close: {torch.allclose(output_triton, output_torch, atol=1e-5)}") diff --git a/tests/test_triton_mma.py b/tests/test_triton_mma.py new file mode 100644 index 0000000..86214f5 --- /dev/null +++ b/tests/test_triton_mma.py @@ -0,0 +1,176 @@ +# coding=utf-8 +import torch + +import triton +import triton.language as tl + + +@triton.jit +def add_kernel(a_ptr, b_ptr, c_ptr, numel, BLOCK_SIZE: tl.constexpr): + xidx = tl.program_id(0) + index = xidx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = index < numel + a = tl.load(a_ptr + index, mask=mask) + b = tl.load(b_ptr + index, mask=mask) + c = a + b + tl.store(c_ptr + index, c, mask=mask) + + +@triton.jit +def add_mat_kernel( + a_ptr, + b_ptr, + c_ptr, + stride_m, + stride_n, + m, + n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + midx = tl.program_id(0) + nidx = tl.program_id(1) + m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (m_offset[:, None] < m) & (n_offset[None, :] < n) + index = m_offset[:, None] * stride_m + n_offset[None, :] * stride_n + a = tl.load(a_ptr + index, mask=mask) + b = tl.load(b_ptr + index, mask=mask) + c = a + b + tl.store(c_ptr + index, c, mask=mask) + + +@triton.jit +def threed_mat_kernel( + a_ptr, + b_ptr, + c_ptr, + stride_1, + stride_m, + stride_n, + num_token, + m, + n, + TOKEN_BLOCK: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + token_idx = tl.program_id(0) + midx = tl.program_id(1) + nidx = tl.program_id(2) + # tl.device_print("token idx:", token_idx) + m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (m_offset[:, None] < m) & (n_offset[None, :] < n) + index = ( + token_idx * stride_1 + + m_offset[:, None] * stride_m + + n_offset[None, :] * stride_n + ) + a = tl.load(a_ptr + index, mask=mask) + b = tl.load(b_ptr + index, mask=mask) + c = a + b + tl.store(c_ptr + index, c, mask=mask) + + +@triton.jit +def mma_kernel( + a_ptr, + b_ptr, + c_ptr, + m, + n, + k, + stride_am, + stride_an, + stride_bm, + stride_bk, + stride_cm, + stride_ck, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + midx = tl.program_id(0) + nidx = tl.program_id(1) + a_m_offset = midx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + a_n_offset = nidx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + a_mask = (a_m_offset[:, None] < m) & (a_n_offset[None, :] < n) + a_index = a_m_offset[:, None] * stride_am + a_n_offset[None, :] * stride_an + a = tl.load(a_ptr + a_index, mask=a_mask) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + +def test_add_kernel(): + a = torch.randn(size=(1024,), device="cuda") + b = torch.randn(size=(1024,), device="cuda") + c = torch.empty_like(a) + BLOCK_SIZE = 32 + grid = lambda meta: (triton.cdiv(a.numel(), meta["BLOCK_SIZE"]),) + add_kernel[grid](a, b, c, a.numel(), BLOCK_SIZE) + real_c = a + b + assert torch.allclose(real_c, c), "not equal" + print("all right") + + +def test_add_mat_kernel(): + a = torch.randn(size=(127, 255), device="cuda") + b = torch.randn(size=(127, 255), device="cuda") + c = torch.empty_like(a) + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 16 + grid = lambda meta: ( + triton.cdiv(a.size(0), meta["BLOCK_SIZE_M"]), + triton.cdiv(a.size(1), meta["BLOCK_SIZE_N"]), + ) + add_mat_kernel[grid]( + a, + b, + c, + a.stride(0), + a.stride(1), + a.size(0), + a.size(1), + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ) + real_c = a + b + assert 
torch.allclose(c, real_c), "not equal" + print("all right") + + +def test_three_dimension(): + num_token = 128 + a = torch.randn(size=(num_token, 127, 255), device="cuda") + b = torch.randn(size=(num_token, 127, 255), device="cuda") + c = torch.empty_like(a) + BLOCK_SIZE_M = 32 + BLOCK_SIZE_N = 16 + TOKEN_BLOCK = a.size(0) + grid = lambda meta: ( + a.size(0), + triton.cdiv(a.size(1), meta["BLOCK_SIZE_M"]), + triton.cdiv(a.size(2), meta["BLOCK_SIZE_N"]), + ) + threed_mat_kernel[grid]( + a, + b, + c, + a.stride(0), + a.stride(1), + a.stride(2), + a.size(0), + a.size(1), + a.size(2), + TOKEN_BLOCK, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + ) + real_c = a + b + assert torch.allclose(c, real_c), "not equal" + print("all right") + + +if __name__ == "__main__": + test_add_kernel() + test_add_mat_kernel() + test_three_dimension()
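For reference, the per-tile arithmetic that csrc/quantize.cu implements can be sketched in PyTorch as follows. This is only an illustrative sketch of the intended semantics (divide a tile by its maximum element, then narrow to float8_e5m2); the name quantize_tile_reference is hypothetical, and the Python binding that exposes the CUDA kernel is not part of this diff.

import torch

def quantize_tile_reference(tile: torch.Tensor) -> torch.Tensor:
    # `tile` is one fp16 CUDA tile, i.e. the region handled by a single thread block.
    # The kernel tracks the plain (signed) maximum, not the absolute maximum.
    max_value = tile.max()
    # Divide by the tile max and convert to FP8 E5M2, mirroring __hdiv followed by
    # __nv_cvt_halfraw_to_fp8(..., __NV_SATFINITE, __NV_E5M2).
    return (tile / max_value).to(torch.float8_e5m2)

A test could slice the input into the same tiles the kernel's launch grid uses and compare this reference against the kernel's output once a binding is added.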