#include #include "core.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("add_two_tensors", &add_two_tensors, "add two tensor sum"); m.def("rope_tensors", &rope_tensors, "rope emabedding."); m.def("matmul", &matmul, "matmul"); m.def("matmul_sigmoid", &matmul_sigmoid, "matmul_sigmoid"); m.def("matmul_shared", &matmul_shared, "matmul_shared"); // m.def("gemm_mma", &gemm_mma, "gemm_mma"); m.def("org_mm", &org_mm, "org_mm"); m.def("org_mm_shared", &org_mm_shared, "org_mm"); m.def("org_mm_shared_half", &org_mm_shared_half, "org_mm shared and half precision"); m.def("print_idx", &print_idx, "just_printidx"); m.def("reducemax", &reducemax, "reduce max"); m.def("test_cute_tensor", &test_cute_tensor, "just test cute tensor"); m.def("md_mm", &md_mm, "just a test of multi dimension mm"); m.def("block_sum", &block_sum, "test block sum"); m.def("md_block_sum", &md_block_sum, "multi dimension block sum"); m.def("rms_norm", &rms_norm, "rms noram"); }