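// torch_ext/matrix_add.cu
//
// (File-viewer metadata captured with this copy: 30 lines, 878 B, shown as
//  Plaintext; views: Raw / Normal View / History; last modified
//  2024-11-16 19:26:54 +08:00.)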
// matrix_add.cu
#include <limits>

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>

#include "matrix_add.h"
// Element-wise half-precision addition: out[r][c] = a[r][c] + b[r][c].
//
// Launch layout: 2-D; the x dimension indexes columns and the y dimension
// indexes rows. Each thread visits its (row, col) positions with a grid stride
// in both dimensions, so every element is covered for any grid size. A grid
// that exactly covers rows x cols gives one iteration per thread.
//
// Preconditions (not checked here): `out`, `a` and `b` each point to a dense
// row-major buffer of rows * cols elements in device memory.
//
// `out` may alias `a` or `b`. Each thread reads a[idx] and b[idx] before
// writing out[idx] at the same idx. This is why the pointers are
// deliberately not marked __restrict__.
__global__ void matrix_add_kernel(at::Half *out, const at::Half *a, const at::Half *b, int rows, int cols)
{
    const int64_t row_stride = static_cast<int64_t>(gridDim.y) * blockDim.y;
    const int64_t col_stride = static_cast<int64_t>(gridDim.x) * blockDim.x;

    for (int64_t row = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
         row < rows;
         row += row_stride)
    {
        for (int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
             col < cols;
             col += col_stride)
        {
            // 64-bit flat offset: rows * cols can exceed INT_MAX for large
            // matrices, which would overflow a 32-bit index.
            const int64_t idx = row * cols + col;
            out[idx] = a[idx] + b[idx];
        }
    }
}
// Computes out = a + b element-wise for 2-D float16 CUDA tensors.
//
// Parameters:
//   a, b - inputs: CUDA, dtype float16, 2-D, contiguous, identical shapes,
//          on the same device.
//   out  - preallocated output with the same device, dtype, shape and
//          contiguity. It is written in place and may be the same tensor
//          as `a` or `b`.
//
// The kernel is enqueued on PyTorch's current CUDA stream for the inputs'
// device, so it is ordered with surrounding PyTorch work. This function
// does not synchronize.
//
// Throws (TORCH_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK) if any precondition is
// violated or the kernel launch fails.
void matrix_add(torch::Tensor a, torch::Tensor b, torch::Tensor out)
{
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && out.is_cuda(),
                "matrix_add: all tensors must be CUDA tensors");
    TORCH_CHECK(a.device() == b.device() && a.device() == out.device(),
                "matrix_add: all tensors must be on the same device");
    TORCH_CHECK(a.scalar_type() == at::kHalf && b.scalar_type() == at::kHalf &&
                    out.scalar_type() == at::kHalf,
                "matrix_add: all tensors must have dtype float16");
    TORCH_CHECK(a.dim() == 2, "matrix_add: expected 2-D tensors, got ", a.dim(), "-D");
    TORCH_CHECK(b.sizes() == a.sizes() && out.sizes() == a.sizes(),
                "matrix_add: shape mismatch: a ", a.sizes(), ", b ", b.sizes(),
                ", out ", out.sizes());
    // The kernel indexes with row * cols + col, i.e. it assumes dense
    // row-major storage; strided views would be read/written incorrectly.
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && out.is_contiguous(),
                "matrix_add: all tensors must be contiguous");

    // Nothing to do, and a zero-sized grid is an invalid launch configuration.
    if (a.numel() == 0)
    {
        return;
    }

    const int64_t rows64 = a.size(0);
    const int64_t cols64 = a.size(1);
    // The kernel takes int dimensions; reject sizes that would be truncated.
    TORCH_CHECK(rows64 <= std::numeric_limits<int>::max() &&
                    cols64 <= std::numeric_limits<int>::max(),
                "matrix_add: dimensions too large (", rows64, " x ", cols64, ")");
    const int rows = static_cast<int>(rows64);
    const int cols = static_cast<int>(cols64);

    const dim3 threads(16, 16);
    // Integer ceil-division is exact for all sizes. The previous float-based
    // ceil loses precision above 2^24 and could drop the last block.
    const int64_t blocks_x = (cols64 + threads.x - 1) / threads.x;
    const int64_t blocks_y = (rows64 + threads.y - 1) / threads.y;
    // gridDim.y is limited to 65535 blocks on all current architectures.
    TORCH_CHECK(blocks_y <= 65535,
                "matrix_add: too many rows (", rows64, ") for a 2-D launch with ",
                threads.y, " rows per block");
    const dim3 grid(static_cast<unsigned int>(blocks_x), static_cast<unsigned int>(blocks_y));

    // Launch on the tensors' device and on PyTorch's current stream rather
    // than the legacy default stream.
    const c10::cuda::CUDAGuard device_guard(a.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    matrix_add_kernel<<<grid, threads, 0, stream>>>(
        out.data_ptr<at::Half>(), a.data_ptr<at::Half>(), b.data_ptr<at::Half>(), rows, cols);
    // Surface launch-configuration errors immediately instead of at a later,
    // unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}