[Frontend] enable passing multiple LoRA adapters at once to generate() (#5300)
This commit is contained in:
parent abe855d637
commit 828da0d44e

tests/entrypoints/test_llm_generate_multiple_loras.py (new file, +69 lines)
@@ -0,0 +1,69 @@
+import weakref
+
+import pytest
+# downloading lora to test lora requests
+from huggingface_hub import snapshot_download
+
+from vllm import LLM
+from vllm.lora.request import LoRARequest
+
+from ..conftest import cleanup
+
+MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+
+LORA_NAME = "typeof/zephyr-7b-beta-lora"
+
+pytestmark = pytest.mark.llm
+
+
+@pytest.fixture(scope="module")
+def llm():
+    # pytest caches the fixture so we use weakref.proxy to
+    # enable garbage collection
+    llm = LLM(model=MODEL_NAME,
+              tensor_parallel_size=1,
+              max_model_len=8192,
+              enable_lora=True,
+              max_loras=4,
+              max_lora_rank=64,
+              max_num_seqs=128,
+              enforce_eager=True)
+
+    with llm.deprecate_legacy_api():
+        yield weakref.proxy(llm)
+
+        del llm
+
+    cleanup()
+
+
+@pytest.fixture(scope="session")
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.mark.skip_global_cleanup
+def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
+    lora_request = [
+        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
+        for idx in range(len(PROMPTS))
+    ]
+    # Multiple LoRA requests should be zipped with the prompts one-to-one
+    outputs = llm.generate(PROMPTS, lora_request=lora_request)
+    assert len(PROMPTS) == len(outputs)
+
+    # An exception is raised if lengths of prompts and lora_request differ
+    with pytest.raises(ValueError):
+        outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])
+
+    # A single LoRARequest should be applied to every prompt
+    single_lora_request = lora_request[0]
+    outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
+    assert len(PROMPTS) == len(outputs)
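The test above drives the new API end to end. For reference, a minimal usage sketch (not part of the commit; `adapter_path` stands in for a locally downloaded adapter directory, such as the snapshot_download result used in the test):

from vllm import LLM
from vllm.lora.request import LoRARequest

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", enable_lora=True, max_loras=4)

prompts = ["Hello, my name is", "The capital of France is"]
adapter_path = "/path/to/zephyr-7b-beta-lora"  # stand-in for a downloaded adapter

# One LoRARequest per prompt: each prompt is generated with its own adapter.
lora_requests = [
    LoRARequest("zephyr-lora", idx + 1, adapter_path)
    for idx in range(len(prompts))
]
outputs = llm.generate(prompts, lora_request=lora_requests)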
vllm/entrypoints/llm.py

@@ -170,7 +170,7 @@ class LLM:
                                         Sequence[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...
@@ -182,7 +182,7 @@ class LLM:
                                         Sequence[SamplingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...

@@ -195,7 +195,7 @@ class LLM:
         *,
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...

@@ -208,7 +208,7 @@ class LLM:
         *,
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...

@@ -219,7 +219,7 @@ class LLM:
         sampling_params: None,
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...

@@ -232,7 +232,7 @@ class LLM:
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         ...

@@ -249,7 +249,7 @@ class LLM:
                                         Sequence[SamplingParams]]] = None,
         prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
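Every generate() overload above receives the same one-line widening of lora_request, so both call shapes type-check everywhere. Continuing the sketch from the test section, a single request still broadcasts to all prompts while a list is zipped one-to-one:

# Broadcast: one adapter applied to every prompt.
single = LoRARequest("zephyr-lora", 1, adapter_path)
outputs = llm.generate(prompts, lora_request=single)

# Per-prompt: list length must equal the number of prompts, else ValueError.
outputs = llm.generate(prompts, lora_request=[single] * len(prompts))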
@@ -312,7 +312,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[List[int]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...

@@ -324,7 +324,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...

@@ -337,7 +337,7 @@ class LLM:
         *,
         prompt_token_ids: List[int],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...

@@ -350,7 +350,7 @@ class LLM:
         *,
         prompt_token_ids: List[List[int]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...

@@ -361,7 +361,7 @@ class LLM:
         pooling_params: None,
         prompt_token_ids: Union[List[int], List[List[int]]],
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...

@@ -374,7 +374,7 @@ class LLM:
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         ...

@@ -391,7 +391,7 @@ class LLM:
                                        Sequence[PoolingParams]]] = None,
         prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
         use_tqdm: bool = True,
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> List[EmbeddingRequestOutput]:
         """Generates the completions for the input prompts.
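The encode() overloads for embedding models receive the identical widening. A hypothetical sketch, assuming an embedding model and adapter that vLLM can actually serve with LoRA (both names below are illustrative, not from the commit):

# Hypothetical embedding setup; substitute a real embedding model and adapter.
embed_llm = LLM(model="intfloat/e5-mistral-7b-instruct", enable_lora=True)
embedding_outputs = embed_llm.encode(
    prompts,
    lora_request=[LoRARequest("embed-lora", idx + 1, adapter_path)
                  for idx in range(len(prompts))],
)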
@@ -498,7 +498,7 @@ class LLM:
         inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
-        lora_request: Optional[LoRARequest],
+        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
     ) -> None:
         if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.

@@ -509,20 +509,25 @@ class LLM:
         if isinstance(params, list) and len(params) != num_requests:
             raise ValueError("The lengths of prompts and params "
                              "must be the same.")
+        if isinstance(lora_request,
+                      list) and len(lora_request) != num_requests:
+            raise ValueError("The lengths of prompts and lora_request "
+                             "must be the same.")

         # Add requests to the engine.
         for i, request_inputs in enumerate(inputs):
             self._add_request(
                 request_inputs,
                 params[i] if isinstance(params, Sequence) else params,
-                lora_request=lora_request,
+                lora_request=lora_request[i] if isinstance(
+                    lora_request, Sequence) else lora_request,
             )

     def _add_request(
         self,
         inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        lora_request: Optional[LoRARequest] = None,
+        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(request_id,