torch_ext/csrc/core_bind.cpp

23 lines
1015 B
C++

#include <torch/extension.h>
#include "core.h"
// Python bindings for the kernels/ops declared in core.h.
//
// TORCH_EXTENSION_NAME is supplied by torch.utils.cpp_extension at build time,
// so the Python module name always matches the extension's configured name.
// Each m.def() exposes one C++ function under the given Python name; the third
// argument becomes that function's Python docstring (visible via help()).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    // Elementwise / embedding ops.
    m.def("add_two_tensors", &add_two_tensors, "add two tensors");
    m.def("rope_tensors", &rope_tensors, "rope embedding.");

    // Matrix-multiply variants.
    m.def("matmul", &matmul, "matmul");
    m.def("matmul_sigmoid", &matmul_sigmoid, "matmul_sigmoid");
    m.def("matmul_shared", &matmul_shared, "matmul_shared");
    // NOTE(review): gemm_mma binding is disabled; presumably it is not (yet)
    // declared/implemented in core.h. Re-enable once it is -- TODO confirm.
    // m.def("gemm_mma", &gemm_mma, "gemm_mma");
    m.def("org_mm", &org_mm, "org_mm");
    // Previously documented as plain "org_mm" (copy-paste); now distinguishable
    // from org_mm and consistent with org_mm_shared_half below.
    m.def("org_mm_shared", &org_mm_shared, "org_mm shared");
    m.def("org_mm_shared_half", &org_mm_shared_half, "org_mm shared and half precision");
    m.def("md_mm", &md_mm, "just a test of multi dimension mm");

    // Debug / experimental helpers.
    m.def("print_idx", &print_idx, "just print idx");
    m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor");

    // Reductions and normalization.
    m.def("reducemax", &reducemax, "reduce max");
    m.def("block_sum", &block_sum, "test block sum");
    m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
    m.def("softmax", &softmax, "test softmax example");
}