#ifndef CORE_H
#define CORE_H

#include <torch/extension.h>
// Dispatch over the floating-point scalar types supported by these kernels,
// binding the matching CUDA C++ type as `fi_type` and invoking the callable
// in __VA_ARGS__ once.
//
// Usage:
//   TYPING_DISPATCH(tensor.scalar_type(), [&] { kernel<fi_type>(...); });
//
// Mapping: Float -> float, BFloat16 -> __nv_bfloat16, Half -> __half.
//
// Fix (review): the original switch had no `break` statements, so a Float
// dispatch fell through and re-invoked the callable with __nv_bfloat16 and
// __half as well; it also lacked a `default`, silently doing nothing for
// unsupported dtypes. Each case now breaks, and unsupported types fail loudly.
#define TYPING_DISPATCH(scalar_t, ...)                                  \
    switch (scalar_t)                                                   \
    {                                                                   \
    case at::ScalarType::Float:                                         \
    {                                                                   \
        using fi_type = float;                                          \
        __VA_ARGS__();                                                  \
        break;                                                          \
    }                                                                   \
    case at::ScalarType::BFloat16:                                      \
    {                                                                   \
        using fi_type = __nv_bfloat16;                                  \
        __VA_ARGS__();                                                  \
        break;                                                          \
    }                                                                   \
    case at::ScalarType::Half:                                          \
    {                                                                   \
        using fi_type = __half;                                         \
        __VA_ARGS__();                                                  \
        break;                                                          \
    }                                                                   \
    default:                                                            \
        TORCH_CHECK(false, "TYPING_DISPATCH: unsupported scalar type"); \
    }
|
|
// AT_DISPATCH_CASE entries for the floating dtypes these kernels support
// (float32, float16, bfloat16); plugs into AT_DISPATCH_SWITCH below.
// Naming presumably mirrors vLLM's dispatch helper of the same name.
// (Comments must stay outside the macro: a trailing // would swallow the
// line-continuation backslash.)
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
// ATen-style type dispatch restricted to the floating dtypes above: runs the
// lambda in __VA_ARGS__ with `scalar_t` bound to the C++ type matching TYPE.
// NAME is used by ATen for error reporting on unsupported dtypes.
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// ---- Kernel entry points (implementations live in the extension's .cu
// files, not visible here; per-function notes below are inferred from names
// and signatures — confirm against the kernel sources). ----

// Element-wise addition: presumably output = input1 + input2.
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

// Presumably applies rotary position embedding (RoPE) to `input` into
// `output`, starting at position `rope_index_start`. NOTE(review): layout
// assumptions are in the kernel — verify before reuse.
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);

// Matrix-multiply variants: presumably output = input1 @ input2.
void matmul(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
// Presumably matmul followed by an element-wise sigmoid on the result.
void matmul_sigmoid(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
// Presumably a shared-memory-tiled matmul variant.
void matmul_shared(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);

// at::Tensor &gemm_mma(const at::Tensor &a, const at::Tensor &b);

// "org" (presumably original/reference) matmul variants: c = a @ b.
void org_mm(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
void org_mm_shared(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
// Half-precision counterpart of org_mm_shared.
void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);

// Debug helper; presumably prints thread/block indices from the device.
void print_idx();

// Max reduction from `src` into `dest` (granularity defined by the kernel).
void reducemax(const torch::Tensor &src, torch::Tensor &dest);

// Experiment/debug helper exercising CuTe tensor utilities (per the name).
void test_cute_tensor();

// Multi-dimensional matmul experiment on `src` (per the name).
void md_mm(const torch::Tensor &src);

// Per-block sum reductions from `src` into `dest`.
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
// Multi-dimensional variant of block_sum.
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);

// Presumably in-place RMSNorm over `states` with epsilon `eps` and a scalar
// gain `gamma` (scalar gamma instead of a per-channel weight tensor — note
// this differs from the usual RMSNorm signature; confirm intent).
void rms_norm(torch::Tensor &states, float eps, float gamma);
#endif // CORE_H