[Bugfix] Make torch registration of punica ops optional (#7970)

This commit is contained in:
bnellnm 2024-08-28 18:11:49 -04:00 committed by GitHub
parent fdd9daafa3
commit 3cdfe1f38b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 37 additions and 21 deletions

View File

@ -160,6 +160,9 @@ def _bgmv_expand(
return
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
try:
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand = _bgmv_expand

View File

@ -173,6 +173,9 @@ def _bgmv_expand_slice(
return
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
try:
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice

View File

@ -142,6 +142,9 @@ def _bgmv_shrink(
return
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
try:
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_shrink = _bgmv_shrink

View File

@ -192,6 +192,9 @@ def _sgmv_expand(
return
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
try:
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand = _sgmv_expand

View File

@ -205,6 +205,9 @@ def _sgmv_expand_slice(
return
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
try:
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice

View File

@ -189,6 +189,9 @@ def _sgmv_shrink(
return
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
try:
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_shrink = _sgmv_shrink

View File

@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
import torch
from vllm.triton_utils import HAS_TRITON
from vllm.utils import is_xpu
# FIXME: xpu path doesn't support torch.library.custom_op
if HAS_TRITON and not is_xpu():
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink