// torch_ext/csrc/core.h
#ifndef CORE_H
#define CORE_H
#include <torch/extension.h>
// CUDA half / bfloat16 device types used by TYPING_DISPATCH below.
#include <cuda_fp16.h>
#include <cuda_bf16.h>
// Dispatch on an at::ScalarType value TYPE and expose the matching device
// element type as the alias `fi_type` to the callable passed in __VA_ARGS__.
#define TYPING_DISPATCH(TYPE, ...)                                      \
    switch (TYPE)                                                        \
    {                                                                    \
    case at::ScalarType::Float:                                          \
    {                                                                    \
        using fi_type = float;                                           \
        __VA_ARGS__();                                                   \
        break;                                                           \
    }                                                                    \
    case at::ScalarType::BFloat16:                                       \
    {                                                                    \
        using fi_type = __nv_bfloat16;                                   \
        __VA_ARGS__();                                                   \
        break;                                                           \
    }                                                                    \
    case at::ScalarType::Half:                                           \
    {                                                                    \
        using fi_type = __half;                                          \
        __VA_ARGS__();                                                   \
        break;                                                           \
    }                                                                    \
    default:                                                             \
        TORCH_CHECK(false, "TYPING_DISPATCH: unsupported scalar type");  \
    }
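// Usage sketch (illustrative only; `example_kernel`, `grid`, and `block` are
// hypothetical, not declared in this header). The second argument must be a
// callable; inside it, `fi_type` names the element type selected at runtime:
//
//   TYPING_DISPATCH(input.scalar_type(), [&] {
//       example_kernel<fi_type><<<grid, block>>>(
//           static_cast<fi_type *>(input.data_ptr()), input.numel());
//   });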
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
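// Usage sketch (illustrative only; `example_kernel`, `grid`, and `block` are
// hypothetical). Like the upstream AT_DISPATCH macros, the lambda sees the
// selected element type as `scalar_t`:
//
//   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "example_op", [&] {
//       example_kernel<scalar_t><<<grid, block>>>(
//           input.data_ptr<scalar_t>(), input.numel());
//   });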
// Element-wise addition: output = input1 + input2.
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
// Rotary position embedding (RoPE), starting from rope_index_start.
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);
// Matmul variants: plain, with a sigmoid epilogue, and with shared-memory tiling.
void matmul(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void matmul_sigmoid(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void matmul_shared(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
// at::Tensor &gemm_mma(const at::Tensor &a, const at::Tensor &b);
// Matmul kernels computing c = a @ b: naive, shared-memory, and half-precision shared-memory versions.
void org_mm(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
void org_mm_shared(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
// Debug helper: launches a kernel that prints indices.
void print_idx();
// Max reduction of src into dest.
void reducemax(const torch::Tensor &src, torch::Tensor &dest);
// CuTe (CUTLASS) tensor experiment.
void test_cute_tensor();
void md_mm(const torch::Tensor &src);
// Block-wise sum reductions of src into dest.
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
// In-place RMSNorm over states with epsilon eps and scale gamma.
void rms_norm(torch::Tensor &states, float eps, float gamma);
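// Binding sketch (assumption: a separate binding translation unit exposes these
// functions to Python; module contents below are illustrative, not part of this header):
//
//   #include "core.h"
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
//   {
//       m.def("add_two_tensors", &add_two_tensors, "output = input1 + input2");
//       m.def("rms_norm", &rms_norm, "in-place RMSNorm");
//   }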
#endif