#ifndef CORE_H
#define CORE_H

// ATen/torch types and dispatch macros, plus the CUDA half / bfloat16 types
// referenced by TYPING_DISPATCH below.
#include <torch/extension.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cstdio>

// Dispatches on an at::ScalarType and exposes the matching device type as
// `fi_type` to the body passed in `...` (typically a lambda, invoked once).
#define TYPING_DISPATCH(scalar_t, ...)                  \
    switch (scalar_t)                                   \
    {                                                   \
    case at::ScalarType::Float:                         \
    {                                                   \
        using fi_type = float;                          \
        __VA_ARGS__();                                  \
        break;                                          \
    }                                                   \
    case at::ScalarType::BFloat16:                      \
    {                                                   \
        using fi_type = __nv_bfloat16;                  \
        __VA_ARGS__();                                  \
        break;                                          \
    }                                                   \
    case at::ScalarType::Half:                          \
    {                                                   \
        using fi_type = __half;                         \
        __VA_ARGS__();                                  \
        break;                                          \
    }                                                   \
    default:                                            \
        printf("unsupported scalar type\n");            \
        break;                                          \
    }

// vLLM-style floating-point dispatch built on ATen's AT_DISPATCH machinery.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
    AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)    \
    AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
    AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
    AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Host-side operator entry points (defined in the accompanying sources).
void add_two_tensors(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void rope_tensors(const torch::Tensor &input, torch::Tensor &output, int rope_index_start);
void matmul(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void matmul_sigmoid(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
void matmul_shared(const torch::Tensor &input1, const torch::Tensor &input2, torch::Tensor &output);
// at::Tensor &gemm_mma(const at::Tensor &a, const at::Tensor &b);
void org_mm(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
void org_mm_shared(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
void org_mm_shared_half(const at::Tensor &a, const at::Tensor &b, at::Tensor &c);
void print_idx();
void reducemax(const torch::Tensor &src, torch::Tensor &dest);
void test_cute_tensor();
void md_mm(const torch::Tensor &src);
void block_sum(const torch::Tensor &src, torch::Tensor &dest);
void md_block_sum(const torch::Tensor &src, torch::Tensor &dest);
void rms_norm(torch::Tensor &states, float eps, float gamma);

#endif // CORE_H
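
// Usage sketch (illustrative only, assuming expansion inside a .cu file):
// `rms_norm_kernel` and the launch configuration below are hypothetical
// placeholders, not declared by this header; only TYPING_DISPATCH and
// `fi_type` come from core.h.
//
//   template <typename T>
//   __global__ void rms_norm_kernel(T *states, int hidden, float eps, float gamma);
//
//   void rms_norm(torch::Tensor &states, float eps, float gamma)
//   {
//       const int hidden = static_cast<int>(states.size(-1));
//       const int rows = static_cast<int>(states.numel()) / hidden;
//       TYPING_DISPATCH(states.scalar_type(), [&] {
//           rms_norm_kernel<fi_type><<<rows, 256>>>(
//               reinterpret_cast<fi_type *>(states.data_ptr()),
//               hidden, eps, gamma);
//       });
//   }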