[Bugfix] Make torch registration of punica ops optional (#7970)
This commit is contained in:
parent
fdd9daafa3
commit
3cdfe1f38b
@@ -160,6 +160,9 @@ def _bgmv_expand(
     return
 
 
-bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
-                                      _bgmv_expand,
-                                      mutates_args=["output_tensor"])
+try:
+    bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
+                                          _bgmv_expand,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    bgmv_expand = _bgmv_expand
@@ -173,6 +173,9 @@ def _bgmv_expand_slice(
     return
 
 
-bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
-                                            _bgmv_expand_slice,
-                                            mutates_args=["output_tensor"])
+try:
+    bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
+                                                _bgmv_expand_slice,
+                                                mutates_args=["output_tensor"])
+except AttributeError:
+    bgmv_expand_slice = _bgmv_expand_slice
@@ -142,6 +142,9 @@ def _bgmv_shrink(
     return
 
 
-bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
-                                      _bgmv_shrink,
-                                      mutates_args=["output_tensor"])
+try:
+    bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
+                                          _bgmv_shrink,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    bgmv_shrink = _bgmv_shrink
@@ -192,6 +192,9 @@ def _sgmv_expand(
     return
 
 
-sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
-                                      _sgmv_expand,
-                                      mutates_args=["output_tensor"])
+try:
+    sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
+                                          _sgmv_expand,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    sgmv_expand = _sgmv_expand
@@ -205,6 +205,9 @@ def _sgmv_expand_slice(
     return
 
 
-sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
-                                            _sgmv_expand_slice,
-                                            mutates_args=["output_tensor"])
+try:
+    sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
+                                                _sgmv_expand_slice,
+                                                mutates_args=["output_tensor"])
+except AttributeError:
+    sgmv_expand_slice = _sgmv_expand_slice
@@ -189,6 +189,9 @@ def _sgmv_shrink(
     return
 
 
-sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
-                                      _sgmv_shrink,
-                                      mutates_args=["output_tensor"])
+try:
+    sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
+                                          _sgmv_shrink,
+                                          mutates_args=["output_tensor"])
+except AttributeError:
+    sgmv_shrink = _sgmv_shrink
@@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
 import torch
 
 from vllm.triton_utils import HAS_TRITON
-from vllm.utils import is_xpu
 
-# FIXME: xpu path doesn't support torch.library.custom_op
-if HAS_TRITON and not is_xpu():
+if HAS_TRITON:
     from vllm.lora.ops.bgmv_expand import bgmv_expand
     from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
     from vllm.lora.ops.bgmv_shrink import bgmv_shrink
Loading…
Reference in New Issue
Block a user