diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst
index b773edfc..1910f265 100644
--- a/docs/source/models/lora.rst
+++ b/docs/source/models/lora.rst
@@ -49,4 +49,43 @@ the third parameter is the path to the LoRA adapter.
 
 Check out `examples/multilora_inference.py `_
-for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
\ No newline at end of file
+for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
+
+Serving LoRA Adapters
+---------------------
+LoRA-adapted models can also be served with the OpenAI-compatible vLLM server. To do so, we use
+``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kick off the server:
+
+.. code-block:: bash
+
+    python -m vllm.entrypoints.openai.api_server \
+        --model meta-llama/Llama-2-7b-hf \
+        --enable-lora \
+        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
+
+The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
+etc.), which will apply to all subsequent requests. Upon querying the ``/v1/models`` endpoint, we should see our LoRA along
+with its base model:
+
+.. code-block:: bash
+
+    curl localhost:8000/v1/models | jq .
+    {
+        "object": "list",
+        "data": [
+            {
+                "id": "meta-llama/Llama-2-7b-hf",
+                "object": "model",
+                ...
+            },
+            {
+                "id": "sql-lora",
+                "object": "model",
+                ...
+            }
+        ]
+    }
+
+Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
+processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
+LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
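+
+For example, the following request targets the ``sql-lora`` adapter registered above via the standard OpenAI completions
+endpoint (the prompt and sampling values here are illustrative placeholders):
+
+.. code-block:: bash
+
+    # illustrative request against the adapter registered via --lora-modules
+    curl http://localhost:8000/v1/completions \
+        -H "Content-Type: application/json" \
+        -d '{
+            "model": "sql-lora",
+            "prompt": "Hello, my name is",
+            "max_tokens": 16,
+            "temperature": 0
+        }'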
diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py
index 8fdd243a..cd445148 100644
--- a/examples/multilora_inference.py
+++ b/examples/multilora_inference.py
@@ -12,7 +12,9 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
 from vllm.lora.request import LoRARequest
 
 
-def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
+def create_test_prompts(
+        lora_path: str
+) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
     """Create a list of test prompts with their sampling parameters.
 
     2 requests for base model, 4 requests for the LoRA. We define 2
diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index 54522f0a..3a359502 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -7,9 +7,11 @@
 import pytest
 import requests
 import ray # using Ray for overall ease of process management, parallel requests, and debugging.
 import openai # use the official client for correctness check
+from huggingface_hub import snapshot_download # downloading lora to test lora requests
 
 MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here
+LORA_NAME = "typeof/zephyr-7b-beta-lora" # technically this needs Mistral-7B-v0.1 as base, but we're not testing generation quality here
 
 pytestmark = pytest.mark.asyncio
@@ -54,7 +56,12 @@ class ServerRunner:
 
 
 @pytest.fixture(scope="session")
-def server():
+def zephyr_lora_files():
+    return snapshot_download(repo_id=LORA_NAME)
+
+
+@pytest.fixture(scope="session")
+def server(zephyr_lora_files):
     ray.init()
     server_runner = ServerRunner.remote([
         "--model",
@@ -64,6 +71,17 @@ def server():
         "--max-model-len",
         "8192",
         "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128"
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -79,8 +97,25 @@ def client():
     yield client
 
 
-async def test_single_completion(server, client: openai.AsyncOpenAI):
-    completion = await client.completions.create(model=MODEL_NAME,
+async def test_check_models(server, client: openai.AsyncOpenAI):
+    models = await client.models.list()
+    models = models.data
+    served_model = models[0]
+    lora_models = models[1:]
+    assert served_model.id == MODEL_NAME
+    assert all(model.root == MODEL_NAME for model in models)
+    assert lora_models[0].id == "zephyr-lora"
+    assert lora_models[1].id == "zephyr-lora2"
+
+
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_single_completion(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
+    completion = await client.completions.create(model=model_name,
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
                                                  temperature=0.0)
@@ -104,7 +139,13 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
                        completion.choices[0].text) >= 5
 
 
-async def test_single_chat_session(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+                                   model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
     }, {
@@ -115,7 +156,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
 
     # test single completion
     chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
     )
@@ -139,11 +180,17 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
     assert message.content is not None and len(message.content) >= 0
 
 
-async def test_completion_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_completion_streaming(server, client: openai.AsyncOpenAI,
+                                    model_name: str):
     prompt = "What is an LLM?"
     single_completion = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
@@ -152,7 +199,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
     single_usage = single_completion.usage
 
     stream = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=prompt,
         max_tokens=5,
         temperature=0.0,
@@ -166,7 +213,13 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == single_output
 
 
-async def test_chat_streaming(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_chat_streaming(server, client: openai.AsyncOpenAI,
+                              model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
     }, {
@@ -177,7 +230,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
 
     # test single completion
     chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
         temperature=0.0,
@@ -187,7 +240,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
 
     # test streaming
     stream = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
         temperature=0.0,
@@ -204,10 +257,16 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output
 
 
-async def test_batch_completions(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_batch_completions(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
     # test simple list
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         max_tokens=5,
         temperature=0.0,
@@ -217,7 +276,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
 
     # test n = 2
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         n=2,
         max_tokens=5,
@@ -236,7 +295,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI):
 
     # test streaming
     batch = await client.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         prompt=["Hello, my name is", "Hello, my name is"],
         max_tokens=5,
         temperature=0.0,
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index deb0fddd..a2176054 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -23,6 +23,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRe
 from vllm.logger import init_logger
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import LoRA
 
 TIMEOUT_KEEP_ALIVE = 5 # seconds
 
@@ -48,6 +49,16 @@ async def lifespan(app: fastapi.FastAPI):
 app = fastapi.FastAPI(lifespan=lifespan)
 
 
+class LoRAParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        lora_list = []
+        for item in values:
+            name, path = item.split('=')
+            lora_list.append(LoRA(name, path))
+        setattr(namespace, self.dest, lora_list)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
@@ -81,6 +92,15 @@ def parse_args():
                         help="The model name used in the API. If not "
                         "specified, the model name will be the same as "
                         "the huggingface name.")
+    parser.add_argument(
+        "--lora-modules",
+        type=str,
+        default=None,
+        nargs='+',
+        action=LoRAParserAction,
+        help=
+        "LoRA module configurations in the format name=path. Multiple modules can be specified."
+    )
     parser.add_argument("--chat-template",
                         type=str,
                         default=None,
@@ -217,8 +237,10 @@ if __name__ == "__main__":
     engine = AsyncLLMEngine.from_engine_args(engine_args)
     openai_serving_chat = OpenAIServingChat(engine, served_model,
                                             args.response_role,
+                                            args.lora_modules,
                                             args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
+    openai_serving_completion = OpenAIServingCompletion(
+        engine, served_model, args.lora_modules)
 
     # Register labels for metrics
     add_global_metrics_labels(model_name=engine_args.model)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index a9e4c355..850797ae 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,7 @@
 import time
 import codecs
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator, Union
+from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -11,7 +11,7 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
 
@@ -22,8 +22,11 @@ class OpenAIServingChat(OpenAIServing):
                  engine: AsyncLLMEngine,
                  served_model: str,
                  response_role: str,
+                 lora_modules: Optional[List[LoRA]] = None,
                  chat_template=None):
-        super().__init__(engine=engine, served_model=served_model)
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
         self.response_role = response_role
         self._load_chat_template(chat_template)
 
@@ -64,11 +67,13 @@ class OpenAIServingChat(OpenAIServing):
             token_ids = self._validate_prompt_and_tokenize(request,
                                                            prompt=prompt)
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
         except ValueError as e:
             return self.create_error_response(str(e))
 
         result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids)
+                                                request_id, token_ids,
+                                                lora_request)
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 191142d2..667b659f 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -15,7 +15,7 @@ from .protocol import (
     UsageInfo,
 )
 from vllm.outputs import RequestOutput
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
 
 logger = init_logger(__name__)
 
@@ -249,8 +249,13 @@ def merge_async_iterators(*iterators):
 
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
-        super().__init__(engine=engine, served_model=served_model)
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
+        super().__init__(engine=engine,
+                         served_model=served_model,
+                         lora_modules=lora_modules)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -284,6 +289,7 @@ class OpenAIServingCompletion(OpenAIServing):
         generators = []
         try:
             sampling_params = request.to_sampling_params()
+            lora_request = self._maybe_get_lora(request)
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
@@ -298,7 +304,8 @@ class OpenAIServingCompletion(OpenAIServing):
                     self.engine.generate(None,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids))
+                                         prompt_token_ids=input_ids,
+                                         lora_request=lora_request))
         except ValueError as e:
             return self.create_error_response(str(e))
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 390f9aeb..09945471 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1,4 +1,5 @@
 import asyncio
+from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 from vllm.logger import init_logger
@@ -9,15 +10,35 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ErrorResponse, LogProbs,
                                               ModelCard, ModelList,
                                               ModelPermission)
+from vllm.lora.request import LoRARequest
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class LoRA:
+    name: str
+    local_path: str
+
+
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, served_model: str):
+    def __init__(self,
+                 engine: AsyncLLMEngine,
+                 served_model: str,
+                 lora_modules: Optional[List[LoRA]] = None):
         self.engine = engine
         self.served_model = served_model
+        if lora_modules is None:
+            self.lora_requests = []
+        else:
+            self.lora_requests = [
+                LoRARequest(
+                    lora_name=lora.name,
+                    lora_int_id=i,
+                    lora_local_path=lora.local_path,
+                ) for i, lora in enumerate(lora_modules, start=1)
+            ]
 
         self.max_model_len = 0
         self.tokenizer = None
@@ -50,6 +71,13 @@ class OpenAIServing:
                       root=self.served_model,
                       permission=[ModelPermission()])
         ]
+        lora_cards = [
+            ModelCard(id=lora.lora_name,
+                      root=self.served_model,
+                      permission=[ModelPermission()])
+            for lora in self.lora_requests
+        ]
+        model_cards.extend(lora_cards)
         return ModelList(data=model_cards)
 
     def _create_logprobs(
@@ -99,11 +127,22 @@ class OpenAIServing:
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model == self.served_model:
             return
+        if request.model in [lora.lora_name for lora in self.lora_requests]:
+            return
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
 
+    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
+        if request.model == self.served_model:
+            return
+        for lora in self.lora_requests:
+            if request.model == lora.lora_name:
+                return lora
+        # if _check_model has been called earlier, this will be unreachable
+        raise ValueError(f"The model `{request.model}` does not exist.")
+
     def _validate_prompt_and_tokenize(
             self, request: Union[ChatCompletionRequest, CompletionRequest],