torch_ext/csrc/core.h
2024-11-18 19:54:12 +08:00

20 lines
1000 B
C

#ifndef TORCH_EXT_CSRC_CORE_H
#define TORCH_EXT_CSRC_CORE_H
// Declarations of the operators exported by this extension. The definitions
// live in other translation units that are not visible from this header.
//
// at::Tensor and torch::Tensor name the same type; every declaration here
// spells it torch::Tensor for consistency, so definitions written with either
// spelling still match.
//
// Conventions suggested by the signatures (all functions return void):
//   * `const torch::Tensor &` parameters are presumably read-only inputs;
//   * non-const `torch::Tensor &` parameters (output / dest / c) are
//     presumably caller-allocated result tensors the callee writes into.
// NOTE(review): shape, dtype and device requirements are not expressed here —
// TODO confirm them against each definition.
#include <torch/extension.h>

// Presumably writes input1 + input2 into output — TODO confirm.
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

// Presumably applies a rotary position embedding (RoPE) to input and writes
// the result to output; rope_index_start is assumed to be the position index
// the embedding starts from — TODO confirm.
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);

// Matrix-multiply variants writing into output. By name, matmul_sigmoid
// presumably fuses a sigmoid on the product and matmul_shared presumably uses
// a shared-memory kernel — TODO confirm.
void matmul(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void matmul_sigmoid(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void matmul_shared(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

// Disabled: gemm_mma is not declared by this header. Original declaration
// kept for reference:
// at::Tensor &gemm_mma(const at::Tensor &a, const at::Tensor &b);

// Further matrix-multiply variants, presumably computing c = a @ b; the
// _shared / _half suffixes presumably select a shared-memory and a
// half-precision implementation respectively — TODO confirm.
void org_mm(const torch::Tensor &a, const torch::Tensor &b, torch::Tensor &c);
void org_mm_shared(const torch::Tensor &a, const torch::Tensor &b, torch::Tensor &c);
void org_mm_shared_half(const torch::Tensor &a, const torch::Tensor &b, torch::Tensor &c);

// Takes no arguments and returns nothing; by name presumably a debugging aid
// that prints thread/block indices — TODO confirm.
void print_idx();

// Presumably a max-reduction of src written into dest — TODO confirm the
// reduced dimension(s).
void reducemax(const torch::Tensor &src, torch::Tensor &dest);

// Takes no arguments and returns nothing; by name presumably an experiment
// exercising CuTe tensors — TODO confirm.
void test_cute_tensor();

// Takes only an input tensor and returns void, so any result must be produced
// through side effects (e.g. printing) — TODO confirm what it does with src.
void md_mm(const torch::Tensor &src);

#endif  // TORCH_EXT_CSRC_CORE_H