From 3a7dd7e367277c47472912e84375fa912df07328 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 24 Jan 2024 17:11:07 -0800 Subject: [PATCH] Support Batch Completion in Server (#2529) --- tests/entrypoints/test_openai_server.py | 55 +++- vllm/entrypoints/openai/serving_completion.py | 263 +++++++++++------- 2 files changed, 214 insertions(+), 104 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 3cef6bfd..54522f0a 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -1,5 +1,6 @@ -import time +import os import subprocess +import time import sys import pytest @@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio class ServerRunner: def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" self.proc = subprocess.Popen( ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, stdout=sys.stdout, stderr=sys.stderr, ) @@ -58,7 +62,8 @@ def server(): "--dtype", "bfloat16", # use half precision for speed and memory savings in CI environment "--max-model-len", - "8192" + "8192", + "--enforce-eager", ]) ray.get(server_runner.ready.remote()) yield server_runner @@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI): assert "".join(chunks) == output +async def test_batch_completions(server, client: openai.AsyncOpenAI): + # test simple list + batch = await client.completions.create( + model=MODEL_NAME, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=MODEL_NAME, + prompt=["Hello, my name is", "Hello, my name is"], + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client. 
+ use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=MODEL_NAME, + prompt=["Hello, my name is", "Hello, my name is"], + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7eaa7de4..8c9a7ad3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,6 +1,7 @@ +import asyncio import time from fastapi import Request -from typing import AsyncGenerator, AsyncIterator +from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional from vllm.logger import init_logger from vllm.utils import random_uuid from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -18,48 +19,68 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing logger = init_logger(__name__) +TypeTokenIDs = list[int] +TypeTopLogProbs = List[Optional[dict[int, float]]] +TypeCreateLogProbsFn = Callable[ + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + async def completion_stream_generator( - request: CompletionRequest, - result_generator: AsyncIterator[RequestOutput], - echo_without_generation, create_logprobs_fn, request_id, created_time, - model_name) -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - has_echoed = [False] * request.n + request: CompletionRequest, + raw_request: Request, + on_abort, + result_generator: AsyncIterator[tuple[int, RequestOutput]], + create_logprobs_fn: TypeCreateLogProbsFn, + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, +) -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n * num_prompts + previous_num_tokens = [0] * request.n * num_prompts + has_echoed = [False] * request.n * num_prompts + + async for prompt_idx, res in result_generator: + + # Abort the request if the client disconnects. 
+ if await raw_request.is_disconnected(): + await on_abort(f"{request_id}-{prompt_idx}") + raise StopAsyncIteration() - async for res in result_generator: - # TODO: handle client disconnect for streaming for output in res.outputs: - i = output.index - delta_text = output.text[len(previous_texts[i]):] - token_ids = output.token_ids[previous_num_tokens[i]:] - if request.logprobs is not None: - top_logprobs = output.logprobs[previous_num_tokens[i]:] - else: - top_logprobs = None - offsets = len(previous_texts[i]) - if request.echo and not has_echoed[i]: - if not echo_without_generation: - delta_text = res.prompt + delta_text - token_ids = res.prompt_token_ids + token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs + top_logprobs - else: # only just return the prompt - delta_text = res.prompt - token_ids = res.prompt_token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs + i = output.index + prompt_idx * request.n + # TODO(simon): optimize the performance by avoiding full text O(n^2) sending. + + if request.echo and request.max_tokens == 0: + # only return the prompt + delta_text = res.prompt + delta_token_ids = res.prompt_token_ids + top_logprobs = res.prompt_logprobs has_echoed[i] = True + elif request.echo and request.max_tokens > 0 and not has_echoed[i]: + # echo the prompt and first token + delta_text = res.prompt + output.text + delta_token_ids = res.prompt_token_ids + output.token_ids + top_logprobs = res.prompt_logprobs + (output.logprobs or []) + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text[len(previous_texts[i]):] + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + if request.logprobs is not None: + assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" logprobs = create_logprobs_fn( - token_ids=token_ids, + token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, - initial_text_offset=offsets, + initial_text_offset=len(previous_texts[i]), ) else: logprobs = None + previous_texts[i] = output.text previous_num_tokens[i] = len(output.token_ids) finish_reason = output.finish_reason @@ -77,7 +98,7 @@ async def completion_stream_generator( ]).model_dump_json(exclude_unset=True) yield f"data: {response_json}\n\n" - if output.finish_reason is not None: + if output.finish_reason is not None: # return final usage logprobs = LogProbs() if request.logprobs is not None else None prompt_tokens = len(res.prompt_token_ids) completion_tokens = len(output.token_ids) @@ -129,51 +150,58 @@ def parse_prompt_format(prompt) -> tuple[bool, list]: return prompt_is_tokens, prompts -def request_output_to_completion_response(final_res: RequestOutput, request, - echo_without_generation, - create_logprobs_fn, request_id, - created_time, - model_name) -> CompletionResponse: - assert final_res is not None +def request_output_to_completion_response( + final_res_batch: list[RequestOutput], + request: CompletionRequest, + create_logprobs_fn: TypeCreateLogProbsFn, + request_id: str, + created_time: int, + model_name: str, +) -> CompletionResponse: choices = [] - prompt_token_ids = final_res.prompt_token_ids - prompt_logprobs = final_res.prompt_logprobs - prompt_text = final_res.prompt - for output in final_res.outputs: - if request.logprobs is not None: - if not echo_without_generation: - token_ids = output.token_ids - top_logprobs = output.logprobs - if request.echo: - token_ids = 
prompt_token_ids + token_ids - top_logprobs = prompt_logprobs + top_logprobs - else: + num_prompt_tokens = 0 + num_generated_tokens = 0 + for final_res in final_res_batch: + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + + for output in final_res.outputs: + if request.echo and request.max_tokens == 0: token_ids = prompt_token_ids top_logprobs = prompt_logprobs - logprobs = create_logprobs_fn( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - ) - else: - logprobs = None - if not echo_without_generation: - output_text = output.text - if request.echo: - output_text = prompt_text + output_text - else: - output_text = prompt_text - choice_data = CompletionResponseChoice( - index=output.index, - text=output_text, - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - choices.append(choice_data) + output_text = prompt_text + elif request.echo and request.max_tokens > 0: + token_ids = prompt_token_ids + output.token_ids + top_logprobs = prompt_logprobs + output.logprobs + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + top_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + logprobs = create_logprobs_fn( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens += len(prompt_token_ids) + num_generated_tokens += sum( + len(output.token_ids) for output in final_res.outputs) - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) usage = UsageInfo( prompt_tokens=num_prompt_tokens, completion_tokens=num_generated_tokens, @@ -189,6 +217,36 @@ def request_output_to_completion_response(final_res: RequestOutput, request, ) +def merge_async_iterators(*iterators): + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + """ + queue = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i, iterator): + async for item in iterator: + await queue.put((i, item)) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + while not all(finished) or not queue.empty(): + item = await queue.get() + yield item + await asyncio.gather(*_tasks) + + return consumer() + + class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model: str): @@ -210,9 +268,6 @@ class OpenAIServingCompletion(OpenAIServing): if error_check_ret is not None: return error_check_ret - # OpenAI API supports echoing the prompt when max_tokens is 0. - echo_without_generation = request.echo and request.max_tokens == 0 - # Return error for unsupported features. if request.suffix is not None: return self.create_error_response( @@ -226,30 +281,30 @@ class OpenAIServingCompletion(OpenAIServing): created_time = int(time.monotonic()) # Schedule the request and get the result generator. 
+ generators = [] try: sampling_params = request.to_sampling_params() - prompt_is_tokens, prompts = parse_prompt_format(request.prompt) - if len(prompts) > 1: - raise ValueError( - "Batching in completion API is not supported.") - prompt = prompts[0] + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + input_ids = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + input_ids = self._validate_prompt_and_tokenize( + request, prompt=prompt) - if prompt_is_tokens: - input_ids = self._validate_prompt_and_tokenize( - request, prompt_ids=prompt) - else: - input_ids = self._validate_prompt_and_tokenize(request, - prompt=prompt) - - result_generator = self.engine.generate(None, - sampling_params, - request_id, - prompt_token_ids=input_ids) + generators.append( + self.engine.generate(None, + sampling_params, + f"{request_id}-{i}", + prompt_token_ids=input_ids)) except ValueError as e: return self.create_error_response(str(e)) + result_generator: AsyncIterator[tuple[ + int, RequestOutput]] = merge_async_iterators(*generators) + # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use beam search. stream = (request.stream @@ -258,23 +313,27 @@ class OpenAIServingCompletion(OpenAIServing): # Streaming response if stream: - return completion_stream_generator(request, result_generator, - echo_without_generation, + return completion_stream_generator(request, + raw_request, + self.engine.abort, + result_generator, self._create_logprobs, - request_id, created_time, - model_name) + request_id, + created_time, + model_name, + num_prompts=len(prompts)) # Non-streaming response - final_res: RequestOutput = None - async for res in result_generator: + final_res_batch: RequestOutput = [None] * len(prompts) + async for i, res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await self.engine.abort(request_id) + await self.engine.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") - final_res = res + final_res_batch[i] = res response = request_output_to_completion_response( - final_res, request, echo_without_generation, self._create_logprobs, - request_id, created_time, model_name) + final_res_batch, request, self._create_logprobs, request_id, + created_time, model_name) # When user requests streaming but we don't stream, we still need to # return a streaming response with a single event.
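
Reviewer note: the snippet below is a minimal sketch of the batched-prompt call this patch enables, for trying it outside of pytest. It assumes an OpenAI-compatible vLLM server built from this branch is already running; the base URL, API key, and model name are placeholders, not values defined in the patch.

    import openai

    # Assumed local server address and a dummy key; adjust for your deployment.
    client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # A list prompt now fans out into one engine request per prompt
    # ("{request_id}-0", "{request_id}-1", ...) and the results are merged
    # back into a single response with len(prompt) * n choices.
    batch = client.completions.create(
        model="your-served-model",  # placeholder; use the model the server loads
        prompt=["Hello, my name is", "The capital of France is"],
        max_tokens=5,
        temperature=0.0,
    )

    for choice in batch.choices:
        print(choice.index, repr(choice.text))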
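
The merging of per-prompt outputs relies on the new module-level merge_async_iterators helper, which interleaves several async iterators and tags every item with the index of the iterator that produced it (the prompt_idx consumed by completion_stream_generator). The toy driver below is a rough sketch of that contract only; it assumes a vLLM tree containing this patch is importable, and gen() plus the timings are made up for illustration.

    import asyncio

    from vllm.entrypoints.openai.serving_completion import merge_async_iterators

    async def gen(name: str, n: int, delay: float):
        for i in range(n):
            await asyncio.sleep(delay)  # simulate per-token generation latency
            yield f"{name}-{i}"

    async def main():
        # Items arrive in completion order, each paired with its source index,
        # mirroring how per-prompt RequestOutputs are tagged with prompt_idx.
        merged = merge_async_iterators(gen("fast", 3, 0.01), gen("slow", 2, 0.03))
        async for source_idx, item in merged:
            print(source_idx, item)

    asyncio.run(main())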