Support Batch Completion in Server (#2529)

Simon Mo 2024-01-24 17:11:07 -08:00 committed by GitHub
parent 223c19224b
commit 3a7dd7e367
2 changed files with 214 additions and 104 deletions

View File

@@ -1,5 +1,6 @@
-import time
+import os
 import subprocess
+import time
 import sys
 import pytest
@@ -17,8 +18,11 @@ pytestmark = pytest.mark.asyncio
 class ServerRunner:

     def __init__(self, args):
+        env = os.environ.copy()
+        env["PYTHONUNBUFFERED"] = "1"
         self.proc = subprocess.Popen(
             ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args,
+            env=env,
             stdout=sys.stdout,
             stderr=sys.stderr,
         )
@@ -58,7 +62,8 @@ def server():
         "--dtype",
         "bfloat16",  # use half precision for speed and memory savings in CI environment
         "--max-model-len",
-        "8192"
+        "8192",
+        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -199,5 +204,51 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI):
     assert "".join(chunks) == output


+async def test_batch_completions(server, client: openai.AsyncOpenAI):
+    # test simple list
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(batch.choices) == 2
+    assert batch.choices[0].text == batch.choices[1].text
+
+    # test n = 2
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        n=2,
+        max_tokens=5,
+        temperature=0.0,
+        extra_body=dict(
+            # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client.
+            use_beam_search=True),
+    )
+    assert len(batch.choices) == 4
+    assert batch.choices[0].text != batch.choices[
+        1].text, "beam search should be different"
+    assert batch.choices[0].text == batch.choices[
+        2].text, "two copies of the same prompt should be the same"
+    assert batch.choices[1].text == batch.choices[
+        3].text, "two copies of the same prompt should be the same"
+
+    # test streaming
+    batch = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=["Hello, my name is", "Hello, my name is"],
+        max_tokens=5,
+        temperature=0.0,
+        stream=True,
+    )
+    texts = [""] * 2
+    async for chunk in batch:
+        assert len(chunk.choices) == 1
+        choice = chunk.choices[0]
+        texts[choice.index] += choice.text
+    assert texts[0] == texts[1]
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
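
Note: a request like the ones exercised above can also be issued against the server's OpenAI-compatible /v1/completions route with any HTTP client. A minimal sketch, not part of this diff, assuming a locally running server on the default port 8000 and a placeholder model name:

# Illustrative only: the URL and model name are assumptions, not values taken
# from this commit; adjust both for your deployment.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",  # hypothetical served model
        # A list-valued prompt is what this commit adds support for; the server
        # returns one choice per prompt (times n), indexed in order.
        "prompt": ["Hello, my name is", "The capital of France is"],
        "max_tokens": 5,
        "temperature": 0.0,
    },
)
resp.raise_for_status()
for choice in resp.json()["choices"]:
    print(choice["index"], repr(choice["text"]))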

View File

@@ -1,6 +1,7 @@
+import asyncio
 import time
 from fastapi import Request
-from typing import AsyncGenerator, AsyncIterator
+from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional
 from vllm.logger import init_logger
 from vllm.utils import random_uuid
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -18,48 +19,68 @@ from vllm.entrypoints.openai.serving_engine import OpenAIServing

 logger = init_logger(__name__)

+TypeTokenIDs = list[int]
+TypeTopLogProbs = List[Optional[dict[int, float]]]
+TypeCreateLogProbsFn = Callable[
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
+

 async def completion_stream_generator(
     request: CompletionRequest,
-    result_generator: AsyncIterator[RequestOutput],
-    echo_without_generation, create_logprobs_fn, request_id, created_time,
-    model_name) -> AsyncGenerator[str, None]:
-    previous_texts = [""] * request.n
-    previous_num_tokens = [0] * request.n
-    has_echoed = [False] * request.n
+    raw_request: Request,
+    on_abort,
+    result_generator: AsyncIterator[tuple[int, RequestOutput]],
+    create_logprobs_fn: TypeCreateLogProbsFn,
+    request_id: str,
+    created_time: int,
+    model_name: str,
+    num_prompts: int,
+) -> AsyncGenerator[str, None]:
+    previous_texts = [""] * request.n * num_prompts
+    previous_num_tokens = [0] * request.n * num_prompts
+    has_echoed = [False] * request.n * num_prompts

-    async for res in result_generator:
-        # TODO: handle client disconnect for streaming
+    async for prompt_idx, res in result_generator:
+
+        # Abort the request if the client disconnects.
+        if await raw_request.is_disconnected():
+            await on_abort(f"{request_id}-{prompt_idx}")
+            raise StopAsyncIteration()
+
         for output in res.outputs:
-            i = output.index
-            delta_text = output.text[len(previous_texts[i]):]
-            token_ids = output.token_ids[previous_num_tokens[i]:]
-            if request.logprobs is not None:
-                top_logprobs = output.logprobs[previous_num_tokens[i]:]
-            else:
-                top_logprobs = None
-            offsets = len(previous_texts[i])
-            if request.echo and not has_echoed[i]:
-                if not echo_without_generation:
-                    delta_text = res.prompt + delta_text
-                    token_ids = res.prompt_token_ids + token_ids
-                    if top_logprobs:
-                        top_logprobs = res.prompt_logprobs + top_logprobs
-                else:  # only just return the prompt
-                    delta_text = res.prompt
-                    token_ids = res.prompt_token_ids
-                    if top_logprobs:
-                        top_logprobs = res.prompt_logprobs
+            i = output.index + prompt_idx * request.n
+            # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
+
+            if request.echo and request.max_tokens == 0:
+                # only return the prompt
+                delta_text = res.prompt
+                delta_token_ids = res.prompt_token_ids
+                top_logprobs = res.prompt_logprobs
                 has_echoed[i] = True
+            elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
+                # echo the prompt and first token
+                delta_text = res.prompt + output.text
+                delta_token_ids = res.prompt_token_ids + output.token_ids
+                top_logprobs = res.prompt_logprobs + (output.logprobs or [])
+                has_echoed[i] = True
+            else:
+                # return just the delta
+                delta_text = output.text[len(previous_texts[i]):]
+                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
+                top_logprobs = output.logprobs[
+                    previous_num_tokens[i]:] if output.logprobs else None

             if request.logprobs is not None:
+                assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
                 logprobs = create_logprobs_fn(
-                    token_ids=token_ids,
+                    token_ids=delta_token_ids,
                     top_logprobs=top_logprobs,
                     num_output_top_logprobs=request.logprobs,
-                    initial_text_offset=offsets,
+                    initial_text_offset=len(previous_texts[i]),
                 )
             else:
                 logprobs = None
+
             previous_texts[i] = output.text
             previous_num_tokens[i] = len(output.token_ids)
             finish_reason = output.finish_reason
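
Note: a standalone sketch, not part of the diff, of the flat indexing used above. Each prompt contributes request.n streamed samples, and per-choice state is kept in one flat list indexed by output.index + prompt_idx * request.n.

# Illustrative sketch of the flat indexing scheme in completion_stream_generator;
# num_prompts and n stand in for len(prompts) and request.n.
num_prompts = 2
n = 2

previous_texts = [""] * n * num_prompts  # one slot per (prompt, sample) pair

for prompt_idx in range(num_prompts):
    for output_index in range(n):
        i = output_index + prompt_idx * n
        print(f"prompt {prompt_idx}, sample {output_index} -> slot {i}")
# prompt 0, sample 0 -> slot 0
# prompt 0, sample 1 -> slot 1
# prompt 1, sample 0 -> slot 2
# prompt 1, sample 1 -> slot 3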
@@ -77,7 +98,7 @@ async def completion_stream_generator(
                 ]).model_dump_json(exclude_unset=True)
             yield f"data: {response_json}\n\n"

-            if output.finish_reason is not None:
+            if output.finish_reason is not None:  # return final usage
                 logprobs = LogProbs() if request.logprobs is not None else None
                 prompt_tokens = len(res.prompt_token_ids)
                 completion_tokens = len(output.token_ids)
@@ -129,51 +150,58 @@ def parse_prompt_format(prompt) -> tuple[bool, list]:
     return prompt_is_tokens, prompts


-def request_output_to_completion_response(final_res: RequestOutput, request,
-                                           echo_without_generation,
-                                           create_logprobs_fn, request_id,
-                                           created_time,
-                                           model_name) -> CompletionResponse:
-    assert final_res is not None
+def request_output_to_completion_response(
+    final_res_batch: list[RequestOutput],
+    request: CompletionRequest,
+    create_logprobs_fn: TypeCreateLogProbsFn,
+    request_id: str,
+    created_time: int,
+    model_name: str,
+) -> CompletionResponse:
     choices = []
-    prompt_token_ids = final_res.prompt_token_ids
-    prompt_logprobs = final_res.prompt_logprobs
-    prompt_text = final_res.prompt
-    for output in final_res.outputs:
-        if request.logprobs is not None:
-            if not echo_without_generation:
-                token_ids = output.token_ids
-                top_logprobs = output.logprobs
-                if request.echo:
-                    token_ids = prompt_token_ids + token_ids
-                    top_logprobs = prompt_logprobs + top_logprobs
-            else:
-                token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
-            logprobs = create_logprobs_fn(
-                token_ids=token_ids,
-                top_logprobs=top_logprobs,
-                num_output_top_logprobs=request.logprobs,
-            )
-        else:
-            logprobs = None
-        if not echo_without_generation:
-            output_text = output.text
-            if request.echo:
-                output_text = prompt_text + output_text
-        else:
-            output_text = prompt_text
-        choice_data = CompletionResponseChoice(
-            index=output.index,
-            text=output_text,
-            logprobs=logprobs,
-            finish_reason=output.finish_reason,
-        )
-        choices.append(choice_data)
+    num_prompt_tokens = 0
+    num_generated_tokens = 0
+    for final_res in final_res_batch:
+        assert final_res is not None
+        prompt_token_ids = final_res.prompt_token_ids
+        prompt_logprobs = final_res.prompt_logprobs
+        prompt_text = final_res.prompt
+
+        for output in final_res.outputs:
+            if request.echo and request.max_tokens == 0:
+                token_ids = prompt_token_ids
+                top_logprobs = prompt_logprobs
+                output_text = prompt_text
+            elif request.echo and request.max_tokens > 0:
+                token_ids = prompt_token_ids + output.token_ids
+                top_logprobs = prompt_logprobs + output.logprobs
+                output_text = prompt_text + output.text
+            else:
+                token_ids = output.token_ids
+                top_logprobs = output.logprobs
+                output_text = output.text
+
+            if request.logprobs is not None:
+                logprobs = create_logprobs_fn(
+                    token_ids=token_ids,
+                    top_logprobs=top_logprobs,
+                    num_output_top_logprobs=request.logprobs,
+                )
+            else:
+                logprobs = None
+
+            choice_data = CompletionResponseChoice(
+                index=len(choices),
+                text=output_text,
+                logprobs=logprobs,
+                finish_reason=output.finish_reason,
+            )
+            choices.append(choice_data)
+
+        num_prompt_tokens += len(prompt_token_ids)
+        num_generated_tokens += sum(
+            len(output.token_ids) for output in final_res.outputs)

-    num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -189,6 +217,36 @@ def request_output_to_completion_response(final_res: RequestOutput, request,
     )


+def merge_async_iterators(*iterators):
+    """Merge multiple asynchronous iterators into a single iterator.
+
+    This method handle the case where some iterators finish before others.
+    When it yields, it yields a tuple (i, item) where i is the index of the
+    iterator that yields the item.
+    """
+    queue = asyncio.Queue()
+
+    finished = [False] * len(iterators)
+
+    async def producer(i, iterator):
+        async for item in iterator:
+            await queue.put((i, item))
+        finished[i] = True
+
+    _tasks = [
+        asyncio.create_task(producer(i, iterator))
+        for i, iterator in enumerate(iterators)
+    ]
+
+    async def consumer():
+        while not all(finished) or not queue.empty():
+            item = await queue.get()
+            yield item
+        await asyncio.gather(*_tasks)
+
+    return consumer()
+
+
 class OpenAIServingCompletion(OpenAIServing):

     def __init__(self, engine: AsyncLLMEngine, served_model: str):
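
Note: a short usage sketch, not part of the diff, showing how merge_async_iterators interleaves several async generators into (iterator_index, item) pairs. It assumes the merge_async_iterators defined above is already in scope (its module path is not shown in this excerpt).

# Illustrative only: drives merge_async_iterators with two toy async generators.
import asyncio


async def numbers(prefix: str, delay: float):
    for k in range(3):
        await asyncio.sleep(delay)
        yield f"{prefix}{k}"


async def main():
    # merge_async_iterators must be called while an event loop is running,
    # because it creates its producer tasks immediately.
    merged = merge_async_iterators(numbers("a", 0.01), numbers("b", 0.02))
    async for i, item in merged:
        print(i, item)  # e.g. 0 a0, 1 b0, 0 a1, ... as items become available


asyncio.run(main())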
@@ -210,9 +268,6 @@ class OpenAIServingCompletion(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

-        # OpenAI API supports echoing the prompt when max_tokens is 0.
-        echo_without_generation = request.echo and request.max_tokens == 0
-
         # Return error for unsupported features.
         if request.suffix is not None:
             return self.create_error_response(
@@ -226,30 +281,30 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time = int(time.monotonic())

         # Schedule the request and get the result generator.
+        generators = []
         try:
             sampling_params = request.to_sampling_params()
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

-            if len(prompts) > 1:
-                raise ValueError(
-                    "Batching in completion API is not supported.")
-            prompt = prompts[0]
-
-            if prompt_is_tokens:
-                input_ids = self._validate_prompt_and_tokenize(
-                    request, prompt_ids=prompt)
-            else:
-                input_ids = self._validate_prompt_and_tokenize(request,
-                                                               prompt=prompt)
-
-            result_generator = self.engine.generate(None,
-                                                    sampling_params,
-                                                    request_id,
-                                                    prompt_token_ids=input_ids)
+            for i, prompt in enumerate(prompts):
+                if prompt_is_tokens:
+                    input_ids = self._validate_prompt_and_tokenize(
+                        request, prompt_ids=prompt)
+                else:
+                    input_ids = self._validate_prompt_and_tokenize(
+                        request, prompt=prompt)
+
+                generators.append(
+                    self.engine.generate(None,
+                                         sampling_params,
+                                         f"{request_id}-{i}",
+                                         prompt_token_ids=input_ids))
         except ValueError as e:
             return self.create_error_response(str(e))

+        result_generator: AsyncIterator[tuple[
+            int, RequestOutput]] = merge_async_iterators(*generators)
+
         # Similar to the OpenAI API, when n != best_of, we do not stream the
         # results. In addition, we do not stream the results when use beam search.
         stream = (request.stream
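
Note: each prompt in the batch is scheduled as its own engine request under a derived ID, and the disconnect handling aborts those derived IDs. A rough sketch of the ID scheme, using a made-up parent ID:

# Illustrative only: "cmpl-abc123" is a hypothetical parent request ID; in vLLM
# it comes from random_uuid() elsewhere in the server code.
request_id = "cmpl-abc123"
prompts = ["Hello, my name is", "The capital of France is"]

per_prompt_ids = [f"{request_id}-{i}" for i in range(len(prompts))]
print(per_prompt_ids)  # ['cmpl-abc123-0', 'cmpl-abc123-1']
# Aborts are issued against these derived IDs (engine.abort(f"{request_id}-{i}")),
# matching how each prompt's generator was registered with the engine.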
@@ -258,23 +313,27 @@ class OpenAIServingCompletion(OpenAIServing):

         # Streaming response
         if stream:
-            return completion_stream_generator(request, result_generator,
-                                               echo_without_generation,
+            return completion_stream_generator(request,
+                                               raw_request,
+                                               self.engine.abort,
+                                               result_generator,
                                                self._create_logprobs,
-                                               request_id, created_time,
-                                               model_name)
+                                               request_id,
+                                               created_time,
+                                               model_name,
+                                               num_prompts=len(prompts))

         # Non-streaming response
-        final_res: RequestOutput = None
-        async for res in result_generator:
+        final_res_batch: RequestOutput = [None] * len(prompts)
+        async for i, res in result_generator:
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
-                await self.engine.abort(request_id)
+                await self.engine.abort(f"{request_id}-{i}")
                 return self.create_error_response("Client disconnected")
-            final_res = res
+            final_res_batch[i] = res
         response = request_output_to_completion_response(
-            final_res, request, echo_without_generation, self._create_logprobs,
-            request_id, created_time, model_name)
+            final_res_batch, request, self._create_logprobs, request_id,
+            created_time, model_name)

         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.