2023-09-16 15:03:37 +08:00
|
|
|
#include <torch/extension.h>
|
|
|
|
|
|
|
|
|
|
// Quantized GEMM for AWQ.
//
// Forward declaration only: the definition is not visible in this section and
// must be supplied by another translation unit at link time.
//
// Parameters (names taken from the declaration; semantics are not established
// by anything visible here):
//   _in_feats         - NOTE(review): presumably the input activations; confirm
//                       against the kernel definition.
//   _kernel           - NOTE(review): presumably the quantized weights; confirm.
//   _scaling_factors  - NOTE(review): presumably the dequantization scales; confirm.
//   _zeros            - NOTE(review): presumably the quantization zero points; confirm.
//   split_k_iters     - NOTE(review): presumably a split-K iteration count; confirm.
//
// Returns a torch::Tensor; its shape and dtype are determined by the
// definition, which is not shown here.
//
// Tensors are taken by value to keep the signature identical to the external
// definition; changing it here alone would break linking.
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);
|
|
|
|
|
|
2023-10-22 14:14:59 +08:00
|
|
|
// Quantized GEMM for SqueezeLLM.
//
// Forward declaration only: the definition is not visible in this section and
// must be supplied by another translation unit at link time.
//
// Parameters (names taken from the declaration; semantics are not established
// by anything visible here):
//   vec           - NOTE(review): presumably the input vector/activations; confirm
//                   against the kernel definition.
//   mat           - NOTE(review): presumably the quantized weight matrix; confirm.
//   mul           - NOTE(review): presumably the output buffer written in place,
//                   since the function returns void; confirm.
//   lookup_table  - NOTE(review): presumably the dequantization lookup table; confirm.
//
// Returns nothing.
//
// Tensors are taken by value to keep the signature identical to the external
// definition; changing it here alone would break linking.
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);
|
|
|
|
|
|
2023-09-16 15:03:37 +08:00
|
|
|
// Python bindings for the quantized GEMM entry points declared in this file.
// The module name comes from the TORCH_EXTENSION_NAME macro, which is not
// defined in this file and is expected to be provided by the build.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Each binding exposes the C++ function under the same name in Python,
  // with the third argument used as the Python docstring.
  m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
}
|