[Frontend] enable passing multiple LoRA adapters at once to generate() (#5300)

Matthew Goldey 2024-06-06 16:48:13 -04:00 committed by GitHub
parent abe855d637
commit 828da0d44e
2 changed files with 91 additions and 17 deletions
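In short, after this change the lora_request argument of LLM.generate() accepts either a single LoRARequest or a list containing one LoRARequest per prompt. Below is a minimal usage sketch of the new list form, not taken from the commit; the adapter names and local paths are placeholders.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="HuggingFaceH4/zephyr-7b-beta", enable_lora=True, max_loras=4)

prompts = ["Hello, my name is", "The capital of France is"]
# One LoRARequest per prompt: (adapter name, adapter id, local adapter path).
lora_requests = [
    LoRARequest("adapter-a", 1, "/path/to/adapter-a"),
    LoRARequest("adapter-b", 2, "/path/to/adapter-b"),
]

# A list is matched to the prompts by index; a single LoRARequest (or None)
# keeps the old behavior and is applied to every prompt.
outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=32),
                       lora_request=lora_requests)
for output in outputs:
    print(output.outputs[0].text)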


@@ -0,0 +1,69 @@
import weakref

import pytest
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from vllm import LLM
from vllm.lora.request import LoRARequest

from ..conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

LORA_NAME = "typeof/zephyr-7b-beta-lora"

pytestmark = pytest.mark.llm


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(model=MODEL_NAME,
              tensor_parallel_size=1,
              max_model_len=8192,
              enable_lora=True,
              max_loras=4,
              max_lora_rank=64,
              max_num_seqs=128,
              enforce_eager=True)

    with llm.deprecate_legacy_api():
        yield weakref.proxy(llm)

        del llm

    cleanup()


@pytest.fixture(scope="session")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.mark.skip_global_cleanup
def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
    lora_request = [
        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
        for idx in range(len(PROMPTS))
    ]
    # Multiple LoRARequests should be matched with each prompt
    outputs = llm.generate(PROMPTS, lora_request=lora_request)
    assert len(PROMPTS) == len(outputs)

    # Exception raised if the size of lora_request does not match the size of prompts
    with pytest.raises(ValueError):
        outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])

    # Single LoRARequest should be applied to every prompt
    single_lora_request = lora_request[0]
    outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
    assert len(PROMPTS) == len(outputs)
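As the test suggests, the per-prompt list composes with the existing per-prompt SamplingParams: both are validated against the number of prompts and matched by index. A hedged sketch of combining the two, reusing llm, PROMPTS, LORA_NAME, and zephyr_lora_files from the test above:

from vllm import SamplingParams
from vllm.lora.request import LoRARequest

# zephyr_lora_files is the local path returned by snapshot_download(repo_id=LORA_NAME).
lora_requests = [
    LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
    for idx in range(len(PROMPTS))
]
sampling_params = [
    SamplingParams(temperature=0.0, max_tokens=16 * (idx + 1))
    for idx in range(len(PROMPTS))
]

# Both lists must have the same length as PROMPTS; prompt i is generated with
# sampling_params[i] and lora_requests[i].
outputs = llm.generate(PROMPTS, sampling_params, lora_request=lora_requests)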


@@ -170,7 +170,7 @@ class LLM:
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@@ -182,7 +182,7 @@ class LLM:
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@@ -195,7 +195,7 @@ class LLM:
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@@ -208,7 +208,7 @@ class LLM:
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@@ -219,7 +219,7 @@ class LLM:
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@@ -232,7 +232,7 @@ class LLM:
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
...
@@ -249,7 +249,7 @@ class LLM:
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
@@ -312,7 +312,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@@ -324,7 +324,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@@ -337,7 +337,7 @@ class LLM:
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@@ -350,7 +350,7 @@ class LLM:
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@@ -361,7 +361,7 @@ class LLM:
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@@ -374,7 +374,7 @@ class LLM:
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
...
@@ -391,7 +391,7 @@ class LLM:
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
@@ -498,7 +498,7 @@ class LLM:
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
@@ -509,20 +509,25 @@ class LLM:
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if isinstance(lora_request,
list) and len(lora_request) != num_requests:
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")
# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
self._add_request(
request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
)
def _add_request(
self,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
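Taken together, the rule implemented in _validate_and_add_requests is: a list of LoRARequest objects must have the same length as the prompt list and is indexed per prompt, while a single LoRARequest (or None) is reused for every prompt. A simplified standalone restatement of that rule; the helper name is made up for illustration and is not part of vLLM:

from typing import Optional, Sequence, Union

from vllm.lora.request import LoRARequest

def lora_for_prompt(
    lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
    i: int,
    num_requests: int,
) -> Optional[LoRARequest]:
    # Mirrors the length check added to _validate_and_add_requests above.
    if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
        raise ValueError("The lengths of prompts and lora_request "
                         "must be the same.")
    if isinstance(lora_request, Sequence):
        return lora_request[i]  # per-prompt adapter, matched by index
    return lora_request  # single adapter (or None) broadcast to every prompt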