multi-LoRA as extra models in OpenAI server (#2775)

How to serve the LoRAs (mimicking the [multilora inference example](https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py)):
```terminal
$ export LORA_PATH=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
$ python -m vllm.entrypoints.openai.api_server \
 --model meta-llama/Llama-2-7b-hf \
 --enable-lora \
 --lora-modules sql-lora=$LORA_PATH sql-lora2=$LORA_PATH
```
The above server will list 3 separate entries if the user queries `/models`: one for the base served model, and one for each of the specified LoRA modules. In this case `sql-lora` and `sql-lora2` point to the same underlying LoRA, but this need not be the case. The LoRA config options take the same values they do in `EngineArgs`.
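
For illustration (not part of this change), a minimal sketch of checking that listing with the official `openai` Python client; the base URL and placeholder API key are assumptions for a default local deployment:
```python
from openai import OpenAI

# Assumes the server started above is listening on the default host/port.
# The client library requires an API key even if the local server does not enforce one.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Expect three entries: the base model plus the sql-lora and sql-lora2 adapters.
for model in client.models.list().data:
    print(model.id)
```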

No work has been done here to scope client permissions to specific models.
jvmncs 2024-02-17 15:00:48 -05:00 committed by GitHub
parent 185b2c29e2
commit 8f36444c4f
7 changed files with 200 additions and 27 deletions


@@ -49,4 +49,43 @@ the third parameter is the path to the LoRA adapter.
Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.

Serving LoRA Adapters
---------------------
LoRA-adapted models can also be served with the OpenAI-compatible vLLM server. To do so, we use
``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kick off the server:

.. code-block:: bash

    python -m vllm.entrypoints.openai.api_server \
        --model meta-llama/Llama-2-7b-hf \
        --enable-lora \
        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/

The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
etc.), which will apply to all subsequent requests. Upon querying the ``/models`` endpoint, we should see our LoRA along
with its base model:

.. code-block:: bash

    curl localhost:8000/v1/models | jq .
    {
        "object": "list",
        "data": [
            {
                "id": "meta-llama/Llama-2-7b-hf",
                "object": "model",
                ...
            },
            {
                "id": "sql-lora",
                "object": "model",
                ...
            }
        ]
    }

Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
processed according to the server-wide LoRA configuration (i.e., in parallel with base model requests, and potentially other
LoRA adapter requests, if they were provided and ``max_loras`` is set high enough).
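
To make that last point concrete, here is a minimal sketch (again not part of the diff) of addressing the LoRA by name through the completions endpoint with the official `openai` Python client; the prompt, base URL, and placeholder API key are illustrative assumptions:
```python
from openai import OpenAI

# Assumes a local server launched with --lora-modules sql-lora=... as shown above.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The adapter name passed to --lora-modules is used exactly like a model name.
completion = client.completions.create(
    model="sql-lora",
    prompt="Write a SQL query that selects all users older than 30.",
    max_tokens=64,
    temperature=0.0,
)
print(completion.choices[0].text)
```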


@@ -12,7 +12,9 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
from vllm.lora.request import LoRARequest
def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
def create_test_prompts(
lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
"""Create a list of test prompts with their sampling parameters.
2 requests for base model, 4 requests for the LoRA. We define 2


@@ -7,9 +7,11 @@ import pytest
import requests
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
import openai # use the official client for correctness check
from huggingface_hub import snapshot_download # downloading lora to test lora requests
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 600 seconds
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
pytestmark = pytest.mark.asyncio
@@ -54,7 +56,12 @@ class ServerRunner:
@pytest.fixture(scope="session")
def server():
def zephyr_lora_files():
return snapshot_download(repo_id=LORA_NAME)
@pytest.fixture(scope="session")
def server(zephyr_lora_files):
ray.init()
server_runner = ServerRunner.remote([
"--model",
@@ -64,6 +71,17 @@ def server():
"--max-model-len",
"8192",
"--enforce-eager",
# lora config below
"--enable-lora",
"--lora-modules",
f"zephyr-lora={zephyr_lora_files}",
f"zephyr-lora2={zephyr_lora_files}",
"--max-lora-rank",
"64",
"--max-cpu-loras",
"2",
"--max-num-seqs",
"128"
])
ray.get(server_runner.ready.remote())
yield server_runner
@@ -79,8 +97,25 @@ def client():
yield client
async def test_single_completion(server, client: openai.AsyncOpenAI):
completion = await client.completions.create(model=MODEL_NAME,
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
served_model = models[0]
lora_models = models[1:]
assert served_model.id == MODEL_NAME
assert all(model.root == MODEL_NAME for model in models)
assert lora_models[0].id == "zephyr-lora"
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
model_name: str):
completion = await client.completions.create(model=model_name,
prompt="Hello, my name is",
max_tokens=5,
temperature=0.0)
@@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
completion.choices[0].text) >= 5
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_single_chat_session(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
# test single completion
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
model=model_name,
messages=messages,
max_tokens=10,
)
@@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
assert message.content is not None and len(message.content) >= 0
async def test_completion_streaming(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(server, client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is an LLM?"
single_completion = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
@@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
single_usage = single_completion.usage
stream = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
@@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
assert "".join(chunks) == single_output
async def test_chat_streaming(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_chat_streaming(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
# test single completion
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
@@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
# test streaming
stream = await client.chat.completions.create(
model=MODEL_NAME,
model=model_name,
messages=messages,
max_tokens=10,
temperature=0.0,
@@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
assert "".join(chunks) == output
async def test_batch_completions(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
model_name: str):
# test simple list
batch = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,
@@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
# test n = 2
batch = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
n=2,
max_tokens=5,
@@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
# test streaming
batch = await client.completions.create(
model=MODEL_NAME,
model=model_name,
prompt=["Hello, my name is", "Hello, my name is"],
max_tokens=5,
temperature=0.0,


@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA
TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI):
app = fastapi.FastAPI(lifespan=lifespan)
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
@@ -81,6 +92,15 @@ def parse_args():
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
parser.add_argument("--chat-template",
type=str,
default=None,
@@ -217,8 +237,10 @@ if __name__ == "__main__":
engine = AsyncLLMEngine.from_engine_args(engine_args)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(engine, served_model)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)
# Register labels for metrics
add_global_metrics_labels(model_name=engine_args.model)
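
To illustrate what ``--lora-modules`` produces, here is a standalone sketch that mirrors the ``LoRAParserAction`` and ``LoRA`` definitions added in this diff (re-declared locally rather than imported, so the example is self-contained); the paths are hypothetical:
```python
import argparse
from dataclasses import dataclass

# Local mirror of the LoRA dataclass from serving_engine.py, for illustration only.
@dataclass
class LoRA:
    name: str
    local_path: str

# Local mirror of the LoRAParserAction added to api_server.py above.
class LoRAParserAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)

parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", type=str, nargs='+', action=LoRAParserAction)
args = parser.parse_args(
    ["--lora-modules", "sql-lora=/path/to/lora", "sql-lora2=/path/to/lora"])
print(args.lora_modules)
# [LoRA(name='sql-lora', local_path='/path/to/lora'),
#  LoRA(name='sql-lora2', local_path='/path/to/lora')]
```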


@@ -1,7 +1,7 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Union
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
logger = init_logger(__name__)
@@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing):
engine: AsyncLLMEngine,
served_model: str,
response_role: str,
lora_modules: Optional[List[LoRA]] = None,
chat_template=None):
super().__init__(engine=engine, served_model=served_model)
super().__init__(engine=engine,
served_model=served_model,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)
@@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing):
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt, sampling_params,
request_id, token_ids)
request_id, token_ids,
lora_request)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(


@@ -15,7 +15,7 @@ from .protocol import (
UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
logger = init_logger(__name__)
@@ -249,8 +249,13 @@ def merge_async_iterators(*iterators):
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, served_model: str):
super().__init__(engine=engine, served_model=served_model)
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
lora_modules: Optional[List[LoRA]] = None):
super().__init__(engine=engine,
served_model=served_model,
lora_modules=lora_modules)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing):
generators = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
@@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
self.engine.generate(None,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids))
prompt_token_ids=input_ids,
lora_request=lora_request))
except ValueError as e:
return self.create_error_response(str(e))


@@ -1,4 +1,5 @@
import asyncio
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
@@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
ErrorResponse, LogProbs,
ModelCard, ModelList,
ModelPermission)
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
@dataclass
class LoRA:
name: str
local_path: str
class OpenAIServing:
def __init__(self, engine: AsyncLLMEngine, served_model: str):
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
lora_modules: Optional[List[LoRA]] = None):
self.engine = engine
self.served_model = served_model
if lora_modules is None:
self.lora_requests = []
else:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
) for i, lora in enumerate(lora_modules, start=1)
]
self.max_model_len = 0
self.tokenizer = None
@@ -50,6 +71,13 @@ class OpenAIServing:
root=self.served_model,
permission=[ModelPermission()])
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=self.served_model,
permission=[ModelPermission()])
for lora in self.lora_requests
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
def _create_logprobs(
@@ -99,11 +127,22 @@ class OpenAIServing:
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model:
return
if request.model in [lora.lora_name for lora in self.lora_requests]:
return
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
if request.model == self.served_model:
return
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],