[Frontend] enable passing multiple LoRA adapters at once to generate() (#5300)
This commit is contained in:
parent
abe855d637
commit
828da0d44e
69
tests/entrypoints/test_llm_generate_multiple_loras.py
Normal file
69
tests/entrypoints/test_llm_generate_multiple_loras.py
Normal file
@ -0,0 +1,69 @@
|
||||
import weakref
|
||||
|
||||
import pytest
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
from ..conftest import cleanup
|
||||
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
|
||||
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
||||
|
||||
pytestmark = pytest.mark.llm
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
# enable garbage collection
|
||||
llm = LLM(model=MODEL_NAME,
|
||||
tensor_parallel_size=1,
|
||||
max_model_len=8192,
|
||||
enable_lora=True,
|
||||
max_loras=4,
|
||||
max_lora_rank=64,
|
||||
max_num_seqs=128,
|
||||
enforce_eager=True)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
|
||||
cleanup()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def zephyr_lora_files():
|
||||
return snapshot_download(repo_id=LORA_NAME)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
|
||||
lora_request = [
|
||||
LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
|
||||
for idx in range(len(PROMPTS))
|
||||
]
|
||||
# Multiple SamplingParams should be matched with each prompt
|
||||
outputs = llm.generate(PROMPTS, lora_request=lora_request)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# Exception raised, if the size of params does not match the size of prompts
|
||||
with pytest.raises(ValueError):
|
||||
outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])
|
||||
|
||||
# Single LoRARequest should be applied to every prompt
|
||||
single_lora_request = lora_request[0]
|
||||
outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
@ -170,7 +170,7 @@ class LLM:
|
||||
List[SamplingParams]]] = None,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -182,7 +182,7 @@ class LLM:
|
||||
List[SamplingParams]]] = None,
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -195,7 +195,7 @@ class LLM:
|
||||
*,
|
||||
prompt_token_ids: List[int],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -208,7 +208,7 @@ class LLM:
|
||||
*,
|
||||
prompt_token_ids: List[List[int]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -219,7 +219,7 @@ class LLM:
|
||||
sampling_params: None,
|
||||
prompt_token_ids: Union[List[int], List[List[int]]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -232,7 +232,7 @@ class LLM:
|
||||
sampling_params: Optional[Union[SamplingParams,
|
||||
Sequence[SamplingParams]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
...
|
||||
|
||||
@ -249,7 +249,7 @@ class LLM:
|
||||
Sequence[SamplingParams]]] = None,
|
||||
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[RequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
@ -312,7 +312,7 @@ class LLM:
|
||||
Sequence[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -324,7 +324,7 @@ class LLM:
|
||||
Sequence[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[List[List[int]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -337,7 +337,7 @@ class LLM:
|
||||
*,
|
||||
prompt_token_ids: List[int],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -350,7 +350,7 @@ class LLM:
|
||||
*,
|
||||
prompt_token_ids: List[List[int]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -361,7 +361,7 @@ class LLM:
|
||||
pooling_params: None,
|
||||
prompt_token_ids: Union[List[int], List[List[int]]],
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -374,7 +374,7 @@ class LLM:
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
...
|
||||
|
||||
@ -391,7 +391,7 @@ class LLM:
|
||||
Sequence[PoolingParams]]] = None,
|
||||
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
|
||||
use_tqdm: bool = True,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> List[EmbeddingRequestOutput]:
|
||||
"""Generates the completions for the input prompts.
|
||||
|
||||
@ -498,7 +498,7 @@ class LLM:
|
||||
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
|
||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||
Sequence[PoolingParams]],
|
||||
lora_request: Optional[LoRARequest],
|
||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||
) -> None:
|
||||
if isinstance(inputs, (str, dict)):
|
||||
# Convert a single prompt to a list.
|
||||
@ -509,20 +509,25 @@ class LLM:
|
||||
if isinstance(params, list) and len(params) != num_requests:
|
||||
raise ValueError("The lengths of prompts and params "
|
||||
"must be the same.")
|
||||
if isinstance(lora_request,
|
||||
list) and len(lora_request) != num_requests:
|
||||
raise ValueError("The lengths of prompts and lora_request "
|
||||
"must be the same.")
|
||||
|
||||
# Add requests to the engine.
|
||||
for i, request_inputs in enumerate(inputs):
|
||||
self._add_request(
|
||||
request_inputs,
|
||||
params[i] if isinstance(params, Sequence) else params,
|
||||
lora_request=lora_request,
|
||||
lora_request=lora_request[i] if isinstance(
|
||||
lora_request, Sequence) else lora_request,
|
||||
)
|
||||
|
||||
def _add_request(
|
||||
self,
|
||||
inputs: PromptInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
|
||||
) -> None:
|
||||
request_id = str(next(self.request_counter))
|
||||
self.llm_engine.add_request(request_id,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user