19 lines
962 B
C
19 lines
962 B
C
#ifndef CORE_H
#define CORE_H

// Public declarations for this extension's tensor operations. The
// implementations live in other translation units that are not visible here.
//
// Conventions used by every declaration below:
//   * input tensors are passed by const reference;
//   * the result tensor is passed by non-const reference and nothing is
//     returned, so presumably the caller allocates the result and the callee
//     fills it. TODO confirm the required shape/dtype/device/contiguity
//     against the implementations.
//
// The tensor type is spelled torch::Tensor throughout. It is the same type as
// at::Tensor, so using one spelling does not change any declared function.

#include <torch/extension.h>

// Judging by the name: presumably output = input1 + input2, element-wise.
// NOTE(review): broadcasting/shape-mismatch behaviour is not visible here.
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

// Name suggests a rotary-position-embedding ("RoPE") transform of `input`
// written into `output`. `rope_index_start` is presumably the position offset
// of the first element. TODO confirm.
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);

// Matrix-multiply family. Presumably output = input1 @ input2; the three
// variants differ only in implementation or fused epilogue (unverified):
//   matmul         - baseline version;
//   matmul_sigmoid - name suggests sigmoid(input1 @ input2);
//   matmul_shared  - name suggests a kernel that stages tiles through
//                    on-chip shared memory.
void matmul(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

void matmul_sigmoid(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

void matmul_shared(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

// Disabled declaration, kept for reference:
// at::Tensor &gemm_mma(const at::Tensor &a, const at::Tensor &b);

// Second matrix-multiply family. Presumably c = a @ b. The "_shared" variant
// is presumably shared-memory tiled, and "_half" presumably operates on
// 16-bit floating-point data. NOTE(review): the meaning of the "org" prefix
// is not visible here; confirm the dtype requirements for the "_half" variant.
void org_mm(const torch::Tensor &a, const torch::Tensor &b, torch::Tensor &c);

void org_mm_shared(const torch::Tensor &a, const torch::Tensor &b, torch::Tensor &c);

void org_mm_shared_half(const torch::Tensor &a, const torch::Tensor &b, torch::Tensor &c);

// Takes no arguments and returns nothing. The name suggests a debug helper
// that prints indices (e.g. thread/block ids). Unverified.
void print_idx();

// Name suggests a max-reduction of `src` written into `dest`.
// NOTE(review): the reduced axis and the expected shape of `dest` are not
// visible here.
void reducemax(const torch::Tensor &src, torch::Tensor &dest);

// Takes no arguments and returns nothing. The name suggests a test or
// experiment entry point exercising CUTLASS CuTe tensors. Unverified.
void test_cute_tensor();

#endif  // CORE_H