multi-LoRA as extra models in OpenAI server (#2775)
how to serve the loras (mimicking the [multilora inference example](https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py)): ```terminal $ export LORA_PATH=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/ $ python -m vllm.entrypoints.api_server \ --model meta-llama/Llama-2-7b-hf \ --enable-lora \ --lora-modules sql-lora=$LORA_PATH sql-lora2=$LORA_PATH ``` the above server will list 3 separate values if the user queries `/models`: one for the base served model, and one each for the specified lora modules. in this case sql-lora and sql-lora2 point to the same underlying lora, but this need not be the case. lora config values take the same values they do in EngineArgs no work has been done here to scope client permissions to specific models
This commit is contained in:
parent
185b2c29e2
commit
8f36444c4f
@ -49,4 +49,43 @@ the third parameter is the path to the LoRA adapter.
|
||||
|
||||
|
||||
Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
|
||||
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
|
||||
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
|
||||
|
||||
Serving LoRA Adapters
|
||||
---------------------
|
||||
LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use
|
||||
``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kickoff the server:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python -m vllm.entrypoints.api_server \
|
||||
--model meta-llama/Llama-2-7b-hf \
|
||||
--enable-lora \
|
||||
--lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
|
||||
|
||||
The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
|
||||
etc.), which will apply to all forthcoming requests. Upon querying the ``/models`` endpoint, we should see our LoRA along
|
||||
with its base model:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
curl localhost:8000/v1/models | jq .
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "meta-llama/Llama-2-7b-hf",
|
||||
"object": "model",
|
||||
...
|
||||
},
|
||||
{
|
||||
"id": "sql-lora",
|
||||
"object": "model",
|
||||
...
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
|
||||
processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
|
||||
LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
|
||||
|
||||
@ -12,7 +12,9 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
|
||||
def create_test_prompts(
|
||||
lora_path: str
|
||||
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
|
||||
"""Create a list of test prompts with their sampling parameters.
|
||||
|
||||
2 requests for base model, 4 requests for the LoRA. We define 2
|
||||
|
||||
@ -7,9 +7,11 @@ import pytest
|
||||
import requests
|
||||
import ray # using Ray for overall ease of process management, parallel requests, and debugging.
|
||||
import openai # use the official client for correctness check
|
||||
from huggingface_hub import snapshot_download # downloading lora to test lora requests
|
||||
|
||||
MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
|
||||
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
|
||||
LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
@ -54,7 +56,12 @@ class ServerRunner:
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server():
|
||||
def zephyr_lora_files():
|
||||
return snapshot_download(repo_id=LORA_NAME)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def server(zephyr_lora_files):
|
||||
ray.init()
|
||||
server_runner = ServerRunner.remote([
|
||||
"--model",
|
||||
@ -64,6 +71,17 @@ def server():
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
f"zephyr-lora={zephyr_lora_files}",
|
||||
f"zephyr-lora2={zephyr_lora_files}",
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--max-cpu-loras",
|
||||
"2",
|
||||
"--max-num-seqs",
|
||||
"128"
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
@ -79,8 +97,25 @@ def client():
|
||||
yield client
|
||||
|
||||
|
||||
async def test_single_completion(server, client: openai.AsyncOpenAI):
|
||||
completion = await client.completions.create(model=MODEL_NAME,
|
||||
async def test_check_models(server, client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
served_model = models[0]
|
||||
lora_models = models[1:]
|
||||
assert served_model.id == MODEL_NAME
|
||||
assert all(model.root == MODEL_NAME for model in models)
|
||||
assert lora_models[0].id == "zephyr-lora"
|
||||
assert lora_models[1].id == "zephyr-lora2"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_single_completion(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
|
||||
completion.choices[0].text) >= 5
|
||||
|
||||
|
||||
async def test_single_chat_session(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_single_chat_session(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
)
|
||||
@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
|
||||
async def test_completion_streaming(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_completion_streaming(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
single_completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
|
||||
single_usage = single_completion.usage
|
||||
|
||||
stream = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
|
||||
assert "".join(chunks) == single_output
|
||||
|
||||
|
||||
async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_chat_streaming(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||
|
||||
# test streaming
|
||||
stream = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
|
||||
assert "".join(chunks) == output
|
||||
|
||||
|
||||
async def test_batch_completions(server, client: openai.AsyncOpenAI):
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_batch_completions(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
# test simple list
|
||||
batch = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
prompt=["Hello, my name is", "Hello, my name is"],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
|
||||
|
||||
# test n = 2
|
||||
batch = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
prompt=["Hello, my name is", "Hello, my name is"],
|
||||
n=2,
|
||||
max_tokens=5,
|
||||
@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
|
||||
|
||||
# test streaming
|
||||
batch = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
model=model_name,
|
||||
prompt=["Hello, my name is", "Hello, my name is"],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
|
||||
@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe
|
||||
from vllm.logger import init_logger
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
|
||||
from vllm.entrypoints.openai.serving_engine import LoRA
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI):
|
||||
app = fastapi.FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
class LoRAParserAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
|
||||
lora_list = []
|
||||
for item in values:
|
||||
name, path = item.split('=')
|
||||
lora_list.append(LoRA(name, path))
|
||||
setattr(namespace, self.dest, lora_list)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="vLLM OpenAI-Compatible RESTful API server.")
|
||||
@ -81,6 +92,15 @@ def parse_args():
|
||||
help="The model name used in the API. If not "
|
||||
"specified, the model name will be the same as "
|
||||
"the huggingface name.")
|
||||
parser.add_argument(
|
||||
"--lora-modules",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs='+',
|
||||
action=LoRAParserAction,
|
||||
help=
|
||||
"LoRA module configurations in the format name=path. Multiple modules can be specified."
|
||||
)
|
||||
parser.add_argument("--chat-template",
|
||||
type=str,
|
||||
default=None,
|
||||
@ -217,8 +237,10 @@ if __name__ == "__main__":
|
||||
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
openai_serving_chat = OpenAIServingChat(engine, served_model,
|
||||
args.response_role,
|
||||
args.lora_modules,
|
||||
args.chat_template)
|
||||
openai_serving_completion = OpenAIServingCompletion(engine, served_model)
|
||||
openai_serving_completion = OpenAIServingCompletion(
|
||||
engine, served_model, args.lora_modules)
|
||||
|
||||
# Register labels for metrics
|
||||
add_global_metrics_labels(model_name=engine_args.model)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import time
|
||||
import codecs
|
||||
from fastapi import Request
|
||||
from typing import AsyncGenerator, AsyncIterator, Union
|
||||
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
UsageInfo)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
engine: AsyncLLMEngine,
|
||||
served_model: str,
|
||||
response_role: str,
|
||||
lora_modules: Optional[List[LoRA]] = None,
|
||||
chat_template=None):
|
||||
super().__init__(engine=engine, served_model=served_model)
|
||||
super().__init__(engine=engine,
|
||||
served_model=served_model,
|
||||
lora_modules=lora_modules)
|
||||
self.response_role = response_role
|
||||
self._load_chat_template(chat_template)
|
||||
|
||||
@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
token_ids = self._validate_prompt_and_tokenize(request,
|
||||
prompt=prompt)
|
||||
sampling_params = request.to_sampling_params()
|
||||
lora_request = self._maybe_get_lora(request)
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
result_generator = self.engine.generate(prompt, sampling_params,
|
||||
request_id, token_ids)
|
||||
request_id, token_ids,
|
||||
lora_request)
|
||||
# Streaming response
|
||||
if request.stream:
|
||||
return self.chat_completion_stream_generator(
|
||||
|
||||
@ -15,7 +15,7 @@ from .protocol import (
|
||||
UsageInfo,
|
||||
)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -249,8 +249,13 @@ def merge_async_iterators(*iterators):
|
||||
|
||||
class OpenAIServingCompletion(OpenAIServing):
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, served_model: str):
|
||||
super().__init__(engine=engine, served_model=served_model)
|
||||
def __init__(self,
|
||||
engine: AsyncLLMEngine,
|
||||
served_model: str,
|
||||
lora_modules: Optional[List[LoRA]] = None):
|
||||
super().__init__(engine=engine,
|
||||
served_model=served_model,
|
||||
lora_modules=lora_modules)
|
||||
|
||||
async def create_completion(self, request: CompletionRequest,
|
||||
raw_request: Request):
|
||||
@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
generators = []
|
||||
try:
|
||||
sampling_params = request.to_sampling_params()
|
||||
lora_request = self._maybe_get_lora(request)
|
||||
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
self.engine.generate(None,
|
||||
sampling_params,
|
||||
f"{request_id}-{i}",
|
||||
prompt_token_ids=input_ids))
|
||||
prompt_token_ids=input_ids,
|
||||
lora_request=lora_request))
|
||||
except ValueError as e:
|
||||
return self.create_error_response(str(e))
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
from vllm.logger import init_logger
|
||||
@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||
ErrorResponse, LogProbs,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission)
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRA:
|
||||
name: str
|
||||
local_path: str
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(self, engine: AsyncLLMEngine, served_model: str):
|
||||
def __init__(self,
|
||||
engine: AsyncLLMEngine,
|
||||
served_model: str,
|
||||
lora_modules=Optional[List[LoRA]]):
|
||||
self.engine = engine
|
||||
self.served_model = served_model
|
||||
if lora_modules is None:
|
||||
self.lora_requests = []
|
||||
else:
|
||||
self.lora_requests = [
|
||||
LoRARequest(
|
||||
lora_name=lora.name,
|
||||
lora_int_id=i,
|
||||
lora_local_path=lora.local_path,
|
||||
) for i, lora in enumerate(lora_modules, start=1)
|
||||
]
|
||||
|
||||
self.max_model_len = 0
|
||||
self.tokenizer = None
|
||||
@ -50,6 +71,13 @@ class OpenAIServing:
|
||||
root=self.served_model,
|
||||
permission=[ModelPermission()])
|
||||
]
|
||||
lora_cards = [
|
||||
ModelCard(id=lora.lora_name,
|
||||
root=self.served_model,
|
||||
permission=[ModelPermission()])
|
||||
for lora in self.lora_requests
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
def _create_logprobs(
|
||||
@ -99,11 +127,22 @@ class OpenAIServing:
|
||||
async def _check_model(self, request) -> Optional[ErrorResponse]:
|
||||
if request.model == self.served_model:
|
||||
return
|
||||
if request.model in [lora.lora_name for lora in self.lora_requests]:
|
||||
return
|
||||
return self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
|
||||
if request.model == self.served_model:
|
||||
return
|
||||
for lora in self.lora_requests:
|
||||
if request.model == lora.lora_name:
|
||||
return lora
|
||||
# if _check_model has been called earlier, this will be unreachable
|
||||
raise ValueError("The model `{request.model}` does not exist.")
|
||||
|
||||
def _validate_prompt_and_tokenize(
|
||||
self,
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user