// matrix_add.cu
//
// Element-wise addition of two 2-D half-precision CUDA tensors:
//     out[r][c] = a[r][c] + b[r][c]
//
// NOTE(review): the angle-bracket include targets were missing from the
// original text; the torch/CUDA headers below are reconstructed from the
// names this file uses -- confirm against the build.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <limits>

#include "matrix_add.h"

// Adds two row-major `rows x cols` matrices element by element.
//
// Launch layout: 2-D grid of 2-D blocks. The x dimension maps to columns and
// must cover `cols` (threads past the last column exit immediately). The y
// dimension maps to rows; the kernel strides over rows by the total number of
// y-threads in the grid, so the grid does not have to cover every row.
// Uses no shared memory.
//
// Parameters:
//   out  - destination buffer of rows * cols elements (device memory).
//   a, b - source buffers of rows * cols elements (device memory).
//   rows - number of rows in each matrix.
//   cols - number of columns in each matrix (also the row pitch in elements).
//
// `out` may alias `a` or `b`: every element is read and then written by the
// same thread, so an in-place add is safe. For that reason the pointers are
// deliberately not marked __restrict__.
__global__ void matrix_add_kernel(at::Half *out, const at::Half *a, const at::Half *b, int rows, int cols) {
    // 64-bit index math: row * cols can exceed INT_MAX for large matrices.
    const int64_t col = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (col >= cols) {
        return;
    }

    const int64_t row_stride = static_cast<int64_t>(gridDim.y) * blockDim.y;
    for (int64_t row = static_cast<int64_t>(blockIdx.y) * blockDim.y + threadIdx.y;
         row < rows;
         row += row_stride) {
        const int64_t idx = row * cols + col;
        out[idx] = a[idx] + b[idx];
    }
}

// Host entry point: computes out = a + b on the GPU.
//
// Parameters:
//   a, b - 2-D CUDA tensors of dtype Half with identical shapes. They are
//          made contiguous if they are not already.
//   out  - contiguous 2-D CUDA Half tensor with the same shape and device as
//          `a`; receives the result. May be the same tensor as `a` or `b`.
//
// Raises c10::Error (via TORCH_CHECK) when a precondition is violated or the
// kernel launch fails. The kernel is enqueued asynchronously on PyTorch's
// current CUDA stream for the tensors' device; this function does not
// synchronize.
void matrix_add(torch::Tensor a, torch::Tensor b, torch::Tensor out) {
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && out.is_cuda(),
                "matrix_add: all tensors must be CUDA tensors");
    TORCH_CHECK(a.device() == b.device() && a.device() == out.device(),
                "matrix_add: all tensors must be on the same device");
    TORCH_CHECK(a.scalar_type() == at::kHalf && b.scalar_type() == at::kHalf &&
                    out.scalar_type() == at::kHalf,
                "matrix_add: all tensors must have dtype Half");
    TORCH_CHECK(a.dim() == 2, "matrix_add: expected 2-D tensors, got ", a.dim(), "-D");
    TORCH_CHECK(a.sizes() == b.sizes() && a.sizes() == out.sizes(),
                "matrix_add: shape mismatch between a, b and out");
    TORCH_CHECK(out.is_contiguous(), "matrix_add: out must be contiguous");
    TORCH_CHECK(a.size(0) <= std::numeric_limits<int>::max() &&
                    a.size(1) <= std::numeric_limits<int>::max(),
                "matrix_add: tensor dimensions do not fit in a 32-bit int");

    // The kernel indexes with a dense row-major layout, so strided inputs
    // are compacted first (no copy when already contiguous).
    a = a.contiguous();
    b = b.contiguous();

    const int rows = static_cast<int>(a.size(0));
    const int cols = static_cast<int>(a.size(1));

    // A zero-sized grid is an invalid launch configuration; nothing to do.
    if (rows == 0 || cols == 0) {
        return;
    }

    // Make the tensors' device current for the launch and stream lookup.
    const c10::cuda::CUDAGuard device_guard(a.device());

    const dim3 threads(16, 16);

    // Integer ceil-division; float division loses precision for large sizes.
    const int64_t blocks_x = (static_cast<int64_t>(cols) + threads.x - 1) / threads.x;
    const int64_t blocks_y = (static_cast<int64_t>(rows) + threads.y - 1) / threads.y;

    // gridDim.y is limited to 65535; the kernel's row-stride loop covers any
    // rows beyond what the clamped grid reaches.
    constexpr int64_t kMaxGridY = 65535;
    const dim3 grid(static_cast<unsigned int>(blocks_x),
                    static_cast<unsigned int>(std::min(blocks_y, kMaxGridY)));

    matrix_add_kernel<<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
        out.data_ptr<at::Half>(),
        a.data_ptr<at::Half>(),
        b.data_ptr<at::Half>(),
        rows,
        cols);
    // Kernel launches do not return errors directly; surface bad-config
    // failures here instead of at some later, unrelated CUDA call.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}