30 lines
878 B
Plaintext
30 lines
878 B
Plaintext
|
|
// matrix_add.cu
|
||
|
|
#include <torch/extension.h>
|
||
|
|
#include <cuda_fp16.h>
|
||
|
|
#include "matrix_add.h"
|
||
|
|
|
||
|
|
/**
 * Element-wise half-precision addition: out[r][c] = a[r][c] + b[r][c].
 *
 * Launch layout: 2-D grid of 2-D blocks. The x dimension walks columns and
 * the y dimension walks rows; each thread handles at most one element, so the
 * grid must cover at least `cols` threads in x and `rows` threads in y.
 * Threads that fall outside the matrix do nothing.
 *
 * All three buffers are addressed with the same flat offset
 * `row * cols + col`, i.e. as densely packed row-major rows x cols arrays.
 * No shared memory is used.
 *
 * @param out   destination buffer (rows * cols elements)
 * @param a     first operand (rows * cols elements)
 * @param b     second operand (rows * cols elements)
 * @param rows  number of rows
 * @param cols  number of columns
 */
__global__ void matrix_add_kernel(at::Half *out, const at::Half *a, const at::Half *b, int rows, int cols)
{
    const int row = blockIdx.y * blockDim.y + threadIdx.y;
    const int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < rows && col < cols)
    {
        // Widen before multiplying: rows * cols can exceed INT_MAX even when
        // rows and cols individually fit in an int.
        const int64_t idx = static_cast<int64_t>(row) * cols + col;
        out[idx] = a[idx] + b[idx];
    }
}
|
||
|
|
|
||
|
|
/**
 * Host wrapper: computes out = a + b element-wise for 2-D half-precision
 * CUDA tensors by launching matrix_add_kernel with 16x16 thread blocks.
 *
 * Preconditions (violations raise a c10::Error via TORCH_CHECK):
 *   - a, b and out are CUDA tensors on the same device;
 *   - a is 2-D and b / out have exactly the same shape as a;
 *   - all three tensors are contiguous (the kernel uses flat row-major
 *     indexing and ignores strides);
 *   - dtype is Half (enforced by data_ptr<at::Half>()).
 *
 * The kernel is launched without an explicit stream argument and the call
 * does not synchronize; only launch-configuration errors are reported here.
 * NOTE(review): callers running on a non-default PyTorch stream or a
 * non-current device may want this to use the current stream / a device
 * guard -- confirm against how the extension is used.
 *
 * @param a    first operand, shape (rows, cols)
 * @param b    second operand, same shape as a
 * @param out  destination tensor, same shape as a; overwritten
 */
void matrix_add(torch::Tensor a, torch::Tensor b, torch::Tensor out)
{
    TORCH_CHECK(a.is_cuda() && b.is_cuda() && out.is_cuda(),
                "matrix_add: a, b and out must all be CUDA tensors");
    TORCH_CHECK(a.device() == b.device() && a.device() == out.device(),
                "matrix_add: a, b and out must be on the same device");
    TORCH_CHECK(a.dim() == 2,
                "matrix_add: expected 2-D tensors, got ", a.dim(), " dimension(s)");
    TORCH_CHECK(b.sizes() == a.sizes() && out.sizes() == a.sizes(),
                "matrix_add: a, b and out must have identical shapes");
    TORCH_CHECK(a.is_contiguous() && b.is_contiguous() && out.is_contiguous(),
                "matrix_add: a, b and out must be contiguous");

    // data_ptr<at::Half>() throws if the tensor's dtype is not Half.
    const at::Half *a_ptr = a.data_ptr<at::Half>();
    const at::Half *b_ptr = b.data_ptr<at::Half>();
    at::Half *out_ptr = out.data_ptr<at::Half>();

    const int64_t rows = a.size(0);
    const int64_t cols = a.size(1);

    // A zero-sized grid is an invalid launch configuration; there is nothing
    // to compute for an empty matrix anyway.
    if (rows == 0 || cols == 0)
    {
        return;
    }

    constexpr unsigned int kTile = 16;         // threads per block along each axis
    constexpr int64_t kMaxGridDimY = 65535;    // CUDA limit on gridDim.y
    constexpr int64_t kMaxInt32 = 2147483647;  // kernel receives rows/cols as int

    // Integer ceil-div: exact for every size, unlike the float ceil() form.
    const int64_t grid_x = (cols + kTile - 1) / kTile;
    const int64_t grid_y = (rows + kTile - 1) / kTile;

    TORCH_CHECK(cols <= kMaxInt32,
                "matrix_add: column count ", cols, " does not fit in a 32-bit int");
    TORCH_CHECK(grid_y <= kMaxGridDimY,
                "matrix_add: row count ", rows, " exceeds the supported maximum of ",
                kMaxGridDimY * kTile);

    const dim3 threads(kTile, kTile);
    const dim3 grid(static_cast<unsigned int>(grid_x), static_cast<unsigned int>(grid_y));

    matrix_add_kernel<<<grid, threads>>>(out_ptr, a_ptr, b_ptr,
                                         static_cast<int>(rows), static_cast<int>(cols));

    // Kernel launches do not return a status; pick up configuration errors here.
    const cudaError_t err = cudaGetLastError();
    TORCH_CHECK(err == cudaSuccess,
                "matrix_add: kernel launch failed: ", cudaGetErrorString(err));
}
|