[Core] Support dynamically loading Lora adapter from HuggingFace (#6234)
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
Parent: 69d5ae38dc · Commit: 42c7f66a38
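What the change enables, in short: `lora_path` (formerly `lora_local_path`) may now be a Hugging Face repo id, which is resolved and downloaded at request time. Below is a minimal usage sketch, not part of the diff, assuming a Llama-2 base model and the test adapter repo used in this PR; the prompt text is made up.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Base model is an assumption; any model compatible with the adapter works.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

# The third positional argument is lora_path after this change; a Hugging Face
# repo id works here in addition to a local adapter directory.
outputs = llm.generate(
    ["Write a SQL query that lists all users created in 2023."],
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("sql_adapter", 1,
                             "yard1/llama-2-7b-sql-lora-test"),
)
print(outputs[0].outputs[0].text)
```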
@@ -462,7 +462,7 @@ def test_prefill_schedule_max_lora():
             lora_request=LoRARequest(
                 lora_name=str(i),
                 lora_int_id=i + 1,
-                lora_local_path="abc"))
+                lora_path="abc"))
         waiting.append(seq_group)
     # Add two more requests to verify lora is prioritized.
     # 0: Lora, 1: Lora, 2: regular, 3: regular
@@ -760,7 +760,7 @@ def test_schedule_swapped_max_loras():
             lora_request=LoRARequest(
                 lora_name=str(i),
                 lora_int_id=i + 1,
-                lora_local_path="abc"))
+                lora_path="abc"))
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
         scheduler._swap_out(seq_group, blocks_to_swap_out)
@@ -159,8 +159,14 @@ def dummy_model_gate_up() -> nn.Module:


 @pytest.fixture(scope="session")
-def sql_lora_files():
-    return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
+def sql_lora_huggingface_id():
+    # huggingface repo id is used to test lora runtime downloading.
+    return "yard1/llama-2-7b-sql-lora-test"
+
+
+@pytest.fixture(scope="session")
+def sql_lora_files(sql_lora_huggingface_id):
+    return snapshot_download(repo_id=sql_lora_huggingface_id)


 @pytest.fixture(scope="session")
@@ -29,7 +29,7 @@ def _create_lora_request(lora_id, long_context_infos):
     context_len = long_context_infos[lora_id]["context_length"]
     scaling_factor = context_len_to_scaling_factor[context_len]
     return LoRARequest(context_len, lora_id,
-                       long_context_infos[lora_id]["lora"],
+                       long_context_infos[lora_id]["lora"], None,
                        4096 * scaling_factor)

tests/lora/test_lora_huggingface.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import List
+
+import pytest
+
+from vllm.lora.models import LoRAModel
+from vllm.lora.utils import get_adapter_absolute_path
+from vllm.model_executor.models.llama import LlamaForCausalLM
+
+# Provide absolute path and huggingface lora ids
+lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]
+
+
+@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
+def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
+    lora_name = request.getfixturevalue(lora_fixture_name)
+    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
+    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
+    embedding_modules = LlamaForCausalLM.embedding_modules
+    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
+    expected_lora_modules: List[str] = []
+    for module in supported_lora_modules:
+        if module in packed_modules_mapping:
+            expected_lora_modules.extend(packed_modules_mapping[module])
+        else:
+            expected_lora_modules.append(module)
+
+    lora_path = get_adapter_absolute_path(lora_name)
+
+    # lora loading should work for either an absolute path or a huggingface id.
+    lora_model = LoRAModel.from_local_checkpoint(
+        lora_path,
+        expected_lora_modules,
+        lora_model_id=1,
+        device="cpu",
+        embedding_modules=embedding_modules,
+        embedding_padding_modules=embed_padding_modules)
+
+    # Assertions to ensure the model is loaded correctly
+    assert lora_model is not None, "LoRAModel is not loaded correctly"
@@ -1,9 +1,12 @@
 from collections import OrderedDict
+from unittest.mock import patch

 import pytest
+from huggingface_hub.utils import HfHubHTTPError
 from torch import nn

-from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
+from vllm.lora.utils import (get_adapter_absolute_path,
+                             parse_fine_tuned_lora_name, replace_submodule)
 from vllm.utils import LRUCache


@@ -182,3 +185,55 @@ def test_lru_cache():
     assert 2 in cache
     assert 4 in cache
     assert 6 in cache
+
+
+# Unit tests for get_adapter_absolute_path
+@patch('os.path.isabs')
+def test_get_adapter_absolute_path_absolute(mock_isabs):
+    path = '/absolute/path/to/lora'
+    mock_isabs.return_value = True
+    assert get_adapter_absolute_path(path) == path
+
+
+@patch('os.path.expanduser')
+def test_get_adapter_absolute_path_expanduser(mock_expanduser):
+    # Path with ~ that needs to be expanded
+    path = '~/relative/path/to/lora'
+    absolute_path = '/home/user/relative/path/to/lora'
+    mock_expanduser.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('os.path.exists')
+@patch('os.path.abspath')
+def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
+    # Relative path that exists locally
+    path = 'relative/path/to/lora'
+    absolute_path = '/absolute/path/to/lora'
+    mock_exist.return_value = True
+    mock_abspath.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('huggingface_hub.snapshot_download')
+@patch('os.path.exists')
+def test_get_adapter_absolute_path_huggingface(mock_exist,
+                                               mock_snapshot_download):
+    # Hugging Face model identifier
+    path = 'org/repo'
+    absolute_path = '/mock/snapshot/path'
+    mock_exist.return_value = False
+    mock_snapshot_download.return_value = absolute_path
+    assert get_adapter_absolute_path(path) == absolute_path
+
+
+@patch('huggingface_hub.snapshot_download')
+@patch('os.path.exists')
+def test_get_adapter_absolute_path_huggingface_error(mock_exist,
+                                                      mock_snapshot_download):
+    # Hugging Face model identifier with download error
+    path = 'org/repo'
+    mock_exist.return_value = False
+    mock_snapshot_download.side_effect = HfHubHTTPError(
+        "failed to query model info")
+    assert get_adapter_absolute_path(path) == path
@@ -43,7 +43,7 @@ class PromptAdapterPath:
 @dataclass
 class LoRAModulePath:
     name: str
-    local_path: str
+    path: str


 AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
@@ -83,7 +83,7 @@ class OpenAIServing:
                 LoRARequest(
                     lora_name=lora.name,
                     lora_int_id=i,
-                    lora_local_path=lora.local_path,
+                    lora_path=lora.path,
                 ) for i, lora in enumerate(lora_modules, start=1)
             ]

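For context (not part of the diff): the adapter modules registered with the OpenAI-compatible server are turned into `LoRARequest`s exactly as in the list comprehension above, so each configured module "path" may now also be a Hugging Face repo id. A sketch, assuming `LoRAModulePath` is importable from `vllm.entrypoints.openai.serving_engine`:

```python
from vllm.entrypoints.openai.serving_engine import LoRAModulePath
from vllm.lora.request import LoRARequest

# One configured adapter; the path here is a Hugging Face repo id.
lora_modules = [
    LoRAModulePath(name="sql_adapter",
                   path="yard1/llama-2-7b-sql-lora-test"),
]
lora_requests = [
    LoRARequest(lora_name=m.name, lora_int_id=i, lora_path=m.path)
    for i, m in enumerate(lora_modules, start=1)
]
```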
@@ -1,4 +1,5 @@
-from dataclasses import dataclass
+import warnings
+from dataclasses import dataclass, field
 from typing import Optional

 from vllm.adapter_commons.request import AdapterRequest
@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):

     lora_name: str
     lora_int_id: int
-    lora_local_path: str
+    lora_path: str = ""
+    lora_local_path: Optional[str] = field(default=None, repr=False)
     long_lora_max_len: Optional[int] = None
     __hash__ = AdapterRequest.__hash__

+    def __post_init__(self):
+        if 'lora_local_path' in self.__dict__:
+            warnings.warn(
+                "The 'lora_local_path' attribute is deprecated "
+                "and will be removed in a future version. "
+                "Please use 'lora_path' instead.",
+                DeprecationWarning,
+                stacklevel=2)
+            if not self.lora_path:
+                self.lora_path = self.lora_local_path or ""
+
+        # Ensure lora_path is not empty
+        assert self.lora_path, "lora_path cannot be empty"
+
     @property
     def adapter_id(self):
         return self.lora_int_id
@@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest):
     def name(self):
         return self.lora_name

+    @property
+    def path(self):
+        return self.lora_path
+
     @property
     def local_path(self):
-        return self.lora_local_path
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        return self.lora_path
+
+    @local_path.setter
+    def local_path(self, value):
+        warnings.warn(
+            "The 'local_path' attribute is deprecated "
+            "and will be removed in a future version. "
+            "Please use 'path' instead.",
+            DeprecationWarning,
+            stacklevel=2)
+        self.lora_path = value
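A small sketch (not from the diff) of the compatibility behavior defined above: the deprecated keyword still constructs a valid request, its value is copied into `lora_path` by `__post_init__`, and a `DeprecationWarning` is emitted; the adapter paths below are made up.

```python
from vllm.lora.request import LoRARequest

# New style: lora_path may be a local directory or a Hugging Face repo id.
req = LoRARequest(lora_name="sql_adapter",
                  lora_int_id=1,
                  lora_path="yard1/llama-2-7b-sql-lora-test")

# Old style still works; __post_init__ copies the deprecated value into
# lora_path (and a DeprecationWarning is emitted).
legacy = LoRARequest(lora_name="sql_adapter",
                     lora_int_id=2,
                     lora_local_path="/path/to/adapter")
assert legacy.lora_path == "/path/to/adapter"
assert legacy.path == legacy.lora_path  # new accessor mirrors lora_path
```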
@@ -1,5 +1,9 @@
+import os
 from typing import List, Optional, Set, Tuple, Type

+import huggingface_hub
+from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
+                                   HFValidationError, RepositoryNotFoundError)
 from torch import nn
 from transformers import PretrainedConfig

@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"

     raise ValueError(f"{name} is unsupported LoRA weight")
+
+
+def get_adapter_absolute_path(lora_path: str) -> str:
+    """
+    Resolves the given lora_path to an absolute local path.
+
+    If the lora_path is identified as a Hugging Face model identifier,
+    it will download the model and return the local snapshot path.
+    Otherwise, it treats the lora_path as a local file path and
+    converts it to an absolute path.
+
+    Parameters:
+    lora_path (str): The path to the lora model, which can be an absolute path,
+        a relative path, or a Hugging Face model identifier.
+
+    Returns:
+    str: The resolved absolute local path to the lora model.
+    """
+
+    # Check if the path is an absolute path. Return it whether it exists or not.
+    if os.path.isabs(lora_path):
+        return lora_path
+
+    # If the path starts with ~, expand the user home directory.
+    if lora_path.startswith('~'):
+        return os.path.expanduser(lora_path)
+
+    # Check if the relative path exists locally.
+    if os.path.exists(lora_path):
+        return os.path.abspath(lora_path)
+
+    # If the path does not exist locally, assume it's a Hugging Face repo.
+    try:
+        local_snapshot_path = huggingface_hub.snapshot_download(
+            repo_id=lora_path)
+    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
+            HFValidationError):
+        # Handle errors that may occur during the download.
+        # Return the original path instead of raising an error here.
+        logger.exception("Error downloading the HuggingFace model")
+        return lora_path
+
+    return local_snapshot_path
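A short usage sketch (not part of the diff) of the resolution order implemented by `get_adapter_absolute_path`; the local paths below are made up.

```python
from vllm.lora.utils import get_adapter_absolute_path

# 1. Absolute paths are returned unchanged, whether or not they exist.
print(get_adapter_absolute_path("/data/adapters/sql_lora"))

# 2. "~" is expanded to the user's home directory.
print(get_adapter_absolute_path("~/adapters/sql_lora"))

# 3. An existing relative path is made absolute.
print(get_adapter_absolute_path("adapters/sql_lora"))

# 4. Anything else is treated as a Hugging Face repo id and downloaded via
#    huggingface_hub.snapshot_download; on download errors the original
#    string is returned unchanged.
print(get_adapter_absolute_path("yard1/llama-2-7b-sql-lora-test"))
```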
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
 from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, create_lora_manager)
 from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path

 logger = init_logger(__name__)

@@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
                         packed_modules_mapping[module])
                 else:
                     expected_lora_modules.append(module)
+            lora_path = get_adapter_absolute_path(lora_request.lora_path)
             lora = self._lora_model_cls.from_local_checkpoint(
-                lora_request.lora_local_path,
+                lora_path,
                 expected_lora_modules,
                 max_position_embeddings=self.max_position_embeddings,
                 lora_model_id=lora_request.lora_int_id,
@@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
                 embedding_padding_modules=self.embedding_padding_modules,
             )
         except Exception as e:
-            raise RuntimeError(
-                f"Loading lora {lora_request.lora_local_path} failed") from e
+            raise RuntimeError(f"Loading lora {lora_path} failed") from e
         if lora.rank > self.lora_config.max_lora_rank:
             raise ValueError(
                 f"LoRA rank {lora.rank} is greater than max_lora_rank "
@@ -137,14 +137,13 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args,
     if lora_request is None:
         return None
     try:
-        tokenizer = get_tokenizer(lora_request.lora_local_path, *args,
-                                  **kwargs)
+        tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs)
     except OSError as e:
         # No tokenizer was found in the LoRA folder,
         # use base model tokenizer
         logger.warning(
             "No tokenizer found in %s, using base model tokenizer instead. "
-            "(Exception: %s)", lora_request.lora_local_path, e)
+            "(Exception: %s)", lora_request.lora_path, e)
         tokenizer = None
     return tokenizer

@@ -691,7 +691,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 dummy_lora_request = LoRARequest(
                     lora_name=f"warmup_{lora_id}",
                     lora_int_id=lora_id,
-                    lora_local_path="/not/a/real/path",
+                    lora_path="/not/a/real/path",
                 )
                 self.lora_manager.add_dummy_lora(dummy_lora_request,
                                                  rank=LORA_WARMUP_RANK)