[Bugfix] Make torch registration of punica ops optional (#7970)
This commit is contained in:
parent
fdd9daafa3
commit
3cdfe1f38b
@@ -160,6 +160,9 @@ def _bgmv_expand(
     return
 
 
-bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
-                                      _bgmv_expand,
-                                      mutates_args=["output_tensor"])
+try:
+    bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
+                                          _bgmv_expand,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    bgmv_expand = _bgmv_expand
@@ -173,6 +173,9 @@ def _bgmv_expand_slice(
     return
 
 
-bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
-                                            _bgmv_expand_slice,
-                                            mutates_args=["output_tensor"])
+try:
+    bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
+                                                _bgmv_expand_slice,
+                                                mutates_args=["output_tensor"])
+except AttributeError:
+    bgmv_expand_slice = _bgmv_expand_slice
@@ -142,6 +142,9 @@ def _bgmv_shrink(
     return
 
 
-bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
-                                      _bgmv_shrink,
-                                      mutates_args=["output_tensor"])
+try:
+    bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
+                                          _bgmv_shrink,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    bgmv_shrink = _bgmv_shrink
@@ -192,6 +192,9 @@ def _sgmv_expand(
     return
 
 
-sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
-                                      _sgmv_expand,
-                                      mutates_args=["output_tensor"])
+try:
+    sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
+                                          _sgmv_expand,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    sgmv_expand = _sgmv_expand
@@ -205,6 +205,9 @@ def _sgmv_expand_slice(
     return
 
 
-sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
-                                            _sgmv_expand_slice,
-                                            mutates_args=["output_tensor"])
+try:
+    sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
+                                                _sgmv_expand_slice,
+                                                mutates_args=["output_tensor"])
+except AttributeError:
+    sgmv_expand_slice = _sgmv_expand_slice
@@ -189,6 +189,9 @@ def _sgmv_shrink(
     return
 
 
-sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
-                                      _sgmv_shrink,
-                                      mutates_args=["output_tensor"])
+try:
+    sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
+                                          _sgmv_shrink,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    sgmv_shrink = _sgmv_shrink
@@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
 import torch
 
 from vllm.triton_utils import HAS_TRITON
-from vllm.utils import is_xpu
 
-# FIXME: xpu path doesn't support torch.library.custom_op
-if HAS_TRITON and not is_xpu():
+if HAS_TRITON:
     from vllm.lora.ops.bgmv_expand import bgmv_expand
     from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
     from vllm.lora.ops.bgmv_shrink import bgmv_shrink
Loading…
Reference in New Issue
Block a user