From 8f36444c4f9a55669bcb64e20b5588c0dd72bd93 Mon Sep 17 00:00:00 2001 From: jvmncs Date: Sat, 17 Feb 2024 15:00:48 -0500 Subject: [PATCH] multi-LoRA as extra models in OpenAI server (#2775) how to serve the loras (mimicking the [multilora inference example](https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py)): ```terminal $ export LORA_PATH=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/ $ python -m vllm.entrypoints.api_server \ --model meta-llama/Llama-2-7b-hf \ --enable-lora \ --lora-modules sql-lora=$LORA_PATH sql-lora2=$LORA_PATH ``` the above server will list 3 separate values if the user queries `/models`: one for the base served model, and one each for the specified lora modules. in this case sql-lora and sql-lora2 point to the same underlying lora, but this need not be the case. lora config values take the same values they do in EngineArgs no work has been done here to scope client permissions to specific models --- docs/source/models/lora.rst | 41 ++++++++- examples/multilora_inference.py | 4 +- tests/entrypoints/test_openai_server.py | 89 +++++++++++++++---- vllm/entrypoints/openai/api_server.py | 24 ++++- vllm/entrypoints/openai/serving_chat.py | 13 ++- vllm/entrypoints/openai/serving_completion.py | 15 +++- vllm/entrypoints/openai/serving_engine.py | 41 ++++++++- 7 files changed, 200 insertions(+), 27 deletions(-) diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index b773edfc..1910f265 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -49,4 +49,43 @@ the third parameter is the path to the LoRA adapter. Check out `examples/multilora_inference.py `_ -for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options. \ No newline at end of file +for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options. + +Serving LoRA Adapters +--------------------- +LoRA adapted models can also be served with the Open-AI compatible vLLM server. To do so, we use +``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kickoff the server: + +.. code-block:: bash + + python -m vllm.entrypoints.api_server \ + --model meta-llama/Llama-2-7b-hf \ + --enable-lora \ + --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/ + +The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``, +etc.), which will apply to all forthcoming requests. Upon querying the ``/models`` endpoint, we should see our LoRA along +with its base model: + +.. code-block:: bash + + curl localhost:8000/v1/models | jq . + { + "object": "list", + "data": [ + { + "id": "meta-llama/Llama-2-7b-hf", + "object": "model", + ... + }, + { + "id": "sql-lora", + "object": "model", + ... + } + ] + } + +Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be +processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other +LoRA adapter requests if they were provided and ``max_loras`` is set high enough). diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 8fdd243a..cd445148 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -12,7 +12,9 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput from vllm.lora.request import LoRARequest -def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: +def create_test_prompts( + lora_path: str +) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]: """Create a list of test prompts with their sampling parameters. 2 requests for base model, 4 requests for the LoRA. We define 2 diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 54522f0a..3a359502 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -7,9 +7,11 @@ import pytest import requests import ray # using Ray for overall ease of process management, parallel requests, and debugging. import openai # use the official client for correctness check +from huggingface_hub import snapshot_download # downloading lora to test lora requests MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here +LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here pytestmark = pytest.mark.asyncio @@ -54,7 +56,12 @@ class ServerRunner: @pytest.fixture(scope="session") -def server(): +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="session") +def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ "--model", @@ -64,6 +71,17 @@ def server(): "--max-model-len", "8192", "--enforce-eager", + # lora config below + "--enable-lora", + "--lora-modules", + f"zephyr-lora={zephyr_lora_files}", + f"zephyr-lora2={zephyr_lora_files}", + "--max-lora-rank", + "64", + "--max-cpu-loras", + "2", + "--max-num-seqs", + "128" ]) ray.get(server_runner.ready.remote()) yield server_runner @@ -79,8 +97,25 @@ def client(): yield client -async def test_single_completion(server, client: openai.AsyncOpenAI): - completion = await client.completions.create(model=MODEL_NAME, +async def test_check_models(server, client: openai.AsyncOpenAI): + models = await client.models.list() + models = models.data + served_model = models[0] + lora_models = models[1:] + assert served_model.id == MODEL_NAME + assert all(model.root == MODEL_NAME for model in models) + assert lora_models[0].id == "zephyr-lora" + assert lora_models[1].id == "zephyr-lora2" + + +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_single_completion(server, client: openai.AsyncOpenAI, + model_name: str): + completion = await client.completions.create(model=model_name, prompt="Hello, my name is", max_tokens=5, temperature=0.0) @@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI): completion.choices[0].text) >= 5 -async def test_single_chat_session(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_single_chat_session(server, client: openai.AsyncOpenAI, + model_name: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI): # test single completion chat_completion = await client.chat.completions.create( - model=MODEL_NAME, + model=model_name, messages=messages, max_tokens=10, ) @@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI): assert message.content is not None and len(message.content) >= 0 -async def test_completion_streaming(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_completion_streaming(server, client: openai.AsyncOpenAI, + model_name: str): prompt = "What is an LLM?" single_completion = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, @@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI): single_usage = single_completion.usage stream = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=prompt, max_tokens=5, temperature=0.0, @@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI): assert "".join(chunks) == single_output -async def test_chat_streaming(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_chat_streaming(server, client: openai.AsyncOpenAI, + model_name: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI): # test single completion chat_completion = await client.chat.completions.create( - model=MODEL_NAME, + model=model_name, messages=messages, max_tokens=10, temperature=0.0, @@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI): # test streaming stream = await client.chat.completions.create( - model=MODEL_NAME, + model=model_name, messages=messages, max_tokens=10, temperature=0.0, @@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI): assert "".join(chunks) == output -async def test_batch_completions(server, client: openai.AsyncOpenAI): +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_batch_completions(server, client: openai.AsyncOpenAI, + model_name: str): # test simple list batch = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=["Hello, my name is", "Hello, my name is"], max_tokens=5, temperature=0.0, @@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI): # test n = 2 batch = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=["Hello, my name is", "Hello, my name is"], n=2, max_tokens=5, @@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI): # test streaming batch = await client.completions.create( - model=MODEL_NAME, + model=model_name, prompt=["Hello, my name is", "Hello, my name is"], max_tokens=5, temperature=0.0, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index deb0fddd..a2176054 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe from vllm.logger import init_logger from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_engine import LoRA TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI): app = fastapi.FastAPI(lifespan=lifespan) +class LoRAParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + lora_list = [] + for item in values: + name, path = item.split('=') + lora_list.append(LoRA(name, path)) + setattr(namespace, self.dest, lora_list) + + def parse_args(): parser = argparse.ArgumentParser( description="vLLM OpenAI-Compatible RESTful API server.") @@ -81,6 +92,15 @@ def parse_args(): help="The model name used in the API. If not " "specified, the model name will be the same as " "the huggingface name.") + parser.add_argument( + "--lora-modules", + type=str, + default=None, + nargs='+', + action=LoRAParserAction, + help= + "LoRA module configurations in the format name=path. Multiple modules can be specified." + ) parser.add_argument("--chat-template", type=str, default=None, @@ -217,8 +237,10 @@ if __name__ == "__main__": engine = AsyncLLMEngine.from_engine_args(engine_args) openai_serving_chat = OpenAIServingChat(engine, served_model, args.response_role, + args.lora_modules, args.chat_template) - openai_serving_completion = OpenAIServingCompletion(engine, served_model) + openai_serving_completion = OpenAIServingCompletion( + engine, served_model, args.lora_modules) # Register labels for metrics add_global_metrics_labels(model_name=engine_args.model) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a9e4c355..850797ae 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,7 +1,7 @@ import time import codecs from fastapi import Request -from typing import AsyncGenerator, AsyncIterator, Union +from typing import AsyncGenerator, AsyncIterator, Optional, List, Union from vllm.logger import init_logger from vllm.utils import random_uuid from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) from vllm.outputs import RequestOutput -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA logger = init_logger(__name__) @@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing): engine: AsyncLLMEngine, served_model: str, response_role: str, + lora_modules: Optional[List[LoRA]] = None, chat_template=None): - super().__init__(engine=engine, served_model=served_model) + super().__init__(engine=engine, + served_model=served_model, + lora_modules=lora_modules) self.response_role = response_role self._load_chat_template(chat_template) @@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing): token_ids = self._validate_prompt_and_tokenize(request, prompt=prompt) sampling_params = request.to_sampling_params() + lora_request = self._maybe_get_lora(request) except ValueError as e: return self.create_error_response(str(e)) result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids) + request_id, token_ids, + lora_request) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 191142d2..667b659f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -15,7 +15,7 @@ from .protocol import ( UsageInfo, ) from vllm.outputs import RequestOutput -from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA logger = init_logger(__name__) @@ -249,8 +249,13 @@ def merge_async_iterators(*iterators): class OpenAIServingCompletion(OpenAIServing): - def __init__(self, engine: AsyncLLMEngine, served_model: str): - super().__init__(engine=engine, served_model=served_model) + def __init__(self, + engine: AsyncLLMEngine, + served_model: str, + lora_modules: Optional[List[LoRA]] = None): + super().__init__(engine=engine, + served_model=served_model, + lora_modules=lora_modules) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing): generators = [] try: sampling_params = request.to_sampling_params() + lora_request = self._maybe_get_lora(request) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): @@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing): self.engine.generate(None, sampling_params, f"{request_id}-{i}", - prompt_token_ids=input_ids)) + prompt_token_ids=input_ids, + lora_request=lora_request)) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 390f9aeb..09945471 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,4 +1,5 @@ import asyncio +from dataclasses import dataclass from http import HTTPStatus from typing import Dict, List, Optional, Union from vllm.logger import init_logger @@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest, ErrorResponse, LogProbs, ModelCard, ModelList, ModelPermission) +from vllm.lora.request import LoRARequest logger = init_logger(__name__) +@dataclass +class LoRA: + name: str + local_path: str + + class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, served_model: str): + def __init__(self, + engine: AsyncLLMEngine, + served_model: str, + lora_modules=Optional[List[LoRA]]): self.engine = engine self.served_model = served_model + if lora_modules is None: + self.lora_requests = [] + else: + self.lora_requests = [ + LoRARequest( + lora_name=lora.name, + lora_int_id=i, + lora_local_path=lora.local_path, + ) for i, lora in enumerate(lora_modules, start=1) + ] self.max_model_len = 0 self.tokenizer = None @@ -50,6 +71,13 @@ class OpenAIServing: root=self.served_model, permission=[ModelPermission()]) ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=self.served_model, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + model_cards.extend(lora_cards) return ModelList(data=model_cards) def _create_logprobs( @@ -99,11 +127,22 @@ class OpenAIServing: async def _check_model(self, request) -> Optional[ErrorResponse]: if request.model == self.served_model: return + if request.model in [lora.lora_name for lora in self.lora_requests]: + return return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) + def _maybe_get_lora(self, request) -> Optional[LoRARequest]: + if request.model == self.served_model: + return + for lora in self.lora_requests: + if request.model == lora.lora_name: + return lora + # if _check_model has been called earlier, this will be unreachable + raise ValueError("The model `{request.model}` does not exist.") + def _validate_prompt_and_tokenize( self, request: Union[ChatCompletionRequest, CompletionRequest],