[Bugfix] Make torch registration of punica ops optional (#7970)

This commit is contained in:
bnellnm 2024-08-28 18:11:49 -04:00 committed by GitHub
parent fdd9daafa3
commit 3cdfe1f38b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 37 additions and 21 deletions

View File

@ -160,6 +160,9 @@ def _bgmv_expand(
return return
try:
bgmv_expand = torch.library.custom_op("lora::bgmv_expand", bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand, _bgmv_expand,
mutates_args=["output_tensor"]) mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand = _bgmv_expand

View File

@ -173,6 +173,9 @@ def _bgmv_expand_slice(
return return
try:
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice, _bgmv_expand_slice,
mutates_args=["output_tensor"]) mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice

View File

@ -142,6 +142,9 @@ def _bgmv_shrink(
return return
try:
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink, _bgmv_shrink,
mutates_args=["output_tensor"]) mutates_args=["output_tensor"])
except AttributeError:
bgmv_shrink = _bgmv_shrink

View File

@ -192,6 +192,9 @@ def _sgmv_expand(
return return
try:
sgmv_expand = torch.library.custom_op("lora::sgmv_expand", sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand, _sgmv_expand,
mutates_args=["output_tensor"]) mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand = _sgmv_expand

View File

@ -205,6 +205,9 @@ def _sgmv_expand_slice(
return return
try:
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice, _sgmv_expand_slice,
mutates_args=["output_tensor"]) mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice

View File

@ -189,6 +189,9 @@ def _sgmv_shrink(
return return
try:
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink, _sgmv_shrink,
mutates_args=["output_tensor"]) mutates_args=["output_tensor"])
except AttributeError:
sgmv_shrink = _sgmv_shrink

View File

@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
import torch import torch
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.utils import is_xpu
# FIXME: xpu path doesn't support torch.library.custom_op if HAS_TRITON:
if HAS_TRITON and not is_xpu():
from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.bgmv_shrink import bgmv_shrink