[Misc] remove peft as dependency for prompt models (#8162)
parent 5faedf1b62 · commit 04e7c4e771
@@ -1558,14 +1558,6 @@ class PromptAdapterConfig:
     prompt_adapter_dtype: Optional[torch.dtype] = None
 
     def __post_init__(self):
-        library_name = 'peft'
-        try:
-            __import__(library_name)
-        except ImportError as e:
-            raise ImportError(
-                f"'{library_name}' is not installed for prompt adapter support."
-                f"Please install it using 'pip install {library_name}'."
-            ) from e
 
         if self.max_prompt_adapters < 1:
             raise ValueError(f"max_prompt_adapters "
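With the eager `__import__('peft')` check gone, constructing a
PromptAdapterConfig no longer requires peft to be installed. A minimal
sketch of what this enables; only the two field names are taken from this
diff, the keyword-argument form and the values are assumptions:

    # peft does not need to be importable for this to succeed anymore.
    from vllm.config import PromptAdapterConfig

    config = PromptAdapterConfig(max_prompt_adapters=4,
                                 max_prompt_adapter_token=64)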
@@ -14,6 +14,7 @@ from vllm.config import PromptAdapterConfig
 from vllm.prompt_adapter.layers import (
     VocabParallelEmbeddingWithPromptAdapter)  # yapf: disable
 from vllm.prompt_adapter.layers import PromptAdapterMapping
+from vllm.prompt_adapter.utils import load_peft_weights
 
 logger = logging.getLogger(__name__)
 
@@ -90,7 +91,6 @@ class PromptAdapterModel(AdapterModel):
         config: PromptAdapterConfig,
         device: str = "cuda",
     ) -> "PromptAdapterModel":
-        from peft.utils import load_peft_weights
 
         if num_virtual_tokens > config.max_prompt_adapter_token:
             raise ValueError(
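Taken together, the two hunks above replace the lazy, in-method import of
load_peft_weights from peft.utils with a module-level import of the
vendored copy in vllm.prompt_adapter.utils, so the classmethod runs without
peft installed. A sketch of the call as the new import resolves it; the
path is a placeholder, not a value from this diff:

    from vllm.prompt_adapter.utils import load_peft_weights

    # "/path/to/adapter" stands in for a local directory containing
    # adapter_model.safetensors or adapter_model.bin.
    adapters_weights = load_peft_weights("/path/to/adapter", device="cuda")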
vllm/prompt_adapter/utils.py  (new file, 93 lines)
@@ -0,0 +1,93 @@
+# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420
+
+import os
+from typing import Optional
+
+import torch
+from huggingface_hub import file_exists, hf_hub_download
+from huggingface_hub.utils import EntryNotFoundError
+from safetensors.torch import load_file as safe_load_file
+
+WEIGHTS_NAME = "adapter_model.bin"
+SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"
+
+
+# Get current device name based on available devices
+def infer_device() -> str:
+    if torch.cuda.is_available():
+        return "cuda"
+    return "cpu"
+
+
+def load_peft_weights(model_id: str,
+                      device: Optional[str] = None,
+                      **hf_hub_download_kwargs) -> dict:
+    r"""
+    A helper method to load the PEFT weights from the HuggingFace Hub or locally
+
+    Args:
+        model_id (`str`):
+            The local path to the adapter weights or the name of the adapter to
+            load from the HuggingFace Hub.
+        device (`str`):
+            The device to load the weights onto.
+        hf_hub_download_kwargs (`dict`):
+            Additional arguments to pass to the `hf_hub_download` method when
+            loading from the HuggingFace Hub.
+    """
+    path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
+            if hf_hub_download_kwargs.get("subfolder", None) is not None else
+            model_id)
+
+    if device is None:
+        device = infer_device()
+
+    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
+        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
+        use_safetensors = True
+    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
+        filename = os.path.join(path, WEIGHTS_NAME)
+        use_safetensors = False
+    else:
+        token = hf_hub_download_kwargs.get("token", None)
+        if token is None:
+            token = hf_hub_download_kwargs.get("use_auth_token", None)
+
+        hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"],
+                                     SAFETENSORS_WEIGHTS_NAME)
+                        if hf_hub_download_kwargs.get("subfolder", None)
+                        is not None else SAFETENSORS_WEIGHTS_NAME)
+        has_remote_safetensors_file = file_exists(
+            repo_id=model_id,
+            filename=hub_filename,
+            revision=hf_hub_download_kwargs.get("revision", None),
+            repo_type=hf_hub_download_kwargs.get("repo_type", None),
+            token=token,
+        )
+        use_safetensors = has_remote_safetensors_file
+
+        if has_remote_safetensors_file:
+            # Priority 1: load safetensors weights
+            filename = hf_hub_download(
+                model_id,
+                SAFETENSORS_WEIGHTS_NAME,
+                **hf_hub_download_kwargs,
+            )
+        else:
+            try:
+                filename = hf_hub_download(model_id, WEIGHTS_NAME,
+                                           **hf_hub_download_kwargs)
+            except EntryNotFoundError:
+                raise ValueError(  # noqa: B904
+                    f"Can't find weights for {model_id} in {model_id} or \
+                    in the Hugging Face Hub. "
+                    f"Please check that the file {WEIGHTS_NAME} or \
+                    {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.")
+
+    if use_safetensors:
+        adapters_weights = safe_load_file(filename, device=device)
+    else:
+        adapters_weights = torch.load(filename,
+                                      map_location=torch.device(device))
+
+    return adapters_weights
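The vendored helper resolves weights in a fixed order: a local
adapter_model.safetensors, then a local adapter_model.bin, then the Hub,
preferring the remote safetensors file when it exists. A usage sketch; the
repo id is a placeholder for any PEFT prompt-tuning adapter repository:

    from vllm.prompt_adapter.utils import load_peft_weights

    # Falls back to infer_device() when device is None ("cuda" if
    # available, else "cpu"); Hub downloads go through hf_hub_download.
    weights = load_peft_weights("someuser/some-prompt-adapter", device="cpu")
    print(list(weights.keys()))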