diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index a27f1e7c..ff037756 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -89,6 +89,10 @@ torch::Tensor prepack_B(torch::Tensor const& B, TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("machete_prepack_B", &prepack_B); m.impl("machete_gemm", &gemm); +} + +// use CatchAll since supported_schedules has no tensor arguments +TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) { m.impl("machete_supported_schedules", &supported_schedules); }