// setup.cpp
// Python bindings for the matrix_add extension (matrix_add is declared in matrix_add.h).
#include <torch/extension.h>
#include "matrix_add.h"
// Python module entry point for this extension.
//
// TORCH_EXTENSION_NAME is expected to be defined by the PyTorch C++ extension
// build (torch.utils.cpp_extension), so the compiled module's name matches
// the name Python imports. `m` is the pybind11 module being populated.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  // Expose matrix_add (declared in matrix_add.h) to Python as `matrix_add`.
  // The third argument is the Python-visible docstring.
  // NOTE(review): the docstring advertises FP16; confirm that matrix_add itself
  // validates input dtype/shape, since nothing here checks it.
  m.def("matrix_add", &matrix_add, "FP16 Matrix Add");
}