[torch.compile] fix cpu broken code (#9947)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao 2024-11-01 23:35:47 -07:00 committed by GitHub
parent a78dd3303e
commit af7380d83b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1551,7 +1551,14 @@ def direct_register_custom_op(
"""
if is_in_doc_build():
return
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
import torch.library
if hasattr(torch.library, "infer_schema"):
schema_str = torch.library.infer_schema(op_func,
mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")