// File listing metadata (from the source viewer): 22 lines, 961 B, C++
#include <torch/extension.h>
|
|
#include "core.h"
|
|
|
|
// Python binding table for this extension.
//
// PYBIND11_MODULE defines the module's init function; TORCH_EXTENSION_NAME is
// the module name supplied by the build. Each m.def(...) call below exposes
// one C++ function to Python, taking (Python-visible name, function pointer,
// docstring). The bound functions are presumably declared in "core.h" --
// confirm against that header when adding or renaming entries.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("add_two_tensors", &add_two_tensors, "add two tensor sum");
    m.def("rope_tensors", &rope_tensors, "rope embedding.");

    m.def("matmul", &matmul, "matmul");
    m.def("matmul_sigmoid", &matmul_sigmoid, "matmul_sigmoid");
    m.def("matmul_shared", &matmul_shared, "matmul_shared");
    // NOTE(review): the gemm_mma binding is disabled; the reason is not
    // visible from this file. Re-enable it or remove this line once decided.
    // m.def("gemm_mma", &gemm_mma, "gemm_mma");

    m.def("org_mm", &org_mm, "org_mm");
    // Docstring previously duplicated "org_mm"; use the function's own name so
    // the two bindings can be told apart in Python help().
    m.def("org_mm_shared", &org_mm_shared, "org_mm_shared");
    m.def("org_mm_shared_half", &org_mm_shared_half, "org_mm shared and half precision");

    m.def("print_idx", &print_idx, "just_printidx");
    m.def("reducemax", &reducemax, "reduce max");
    m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor");
    m.def("md_mm", &md_mm, "just a test of multi dimension mm");
    m.def("block_sum", &block_sum, "test block sum");
    m.def("md_block_sum", &md_block_sum, "multi dimension block sum");
}
|