diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 707ab6d2..3cef6bfd 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -88,6 +88,16 @@ async def test_single_completion(server, client: openai.AsyncOpenAI): assert completion.usage == openai.types.CompletionUsage( completion_tokens=5, prompt_tokens=6, total_tokens=11) + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + async def test_single_chat_session(server, client: openai.AsyncOpenAI): messages = [{ diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7a86a19c..6a24e7e9 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,6 +6,7 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field from vllm.utils import random_uuid +from vllm.sampling_params import SamplingParams class ErrorResponse(BaseModel): @@ -78,6 +79,26 @@ class ChatCompletionRequest(BaseModel): repetition_penalty: Optional[float] = 1.0 min_p: Optional[float] = 0.0 + def to_sampling_params(self) -> SamplingParams: + return SamplingParams( + n=self.n, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + min_p=self.min_p, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + max_tokens=self.max_tokens, + best_of=self.best_of, + top_k=self.top_k, + ignore_eos=self.ignore_eos, + use_beam_search=self.use_beam_search, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + ) + class CompletionRequest(BaseModel): model: str @@ -107,6 +128,30 @@ class CompletionRequest(BaseModel): repetition_penalty: Optional[float] = 1.0 min_p: Optional[float] = 0.0 + def to_sampling_params(self): + echo_without_generation = self.echo and self.max_tokens == 0 + + return SamplingParams( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens if not echo_without_generation else 1, + logprobs=self.logprobs, + use_beam_search=self.use_beam_search, + prompt_logprobs=self.logprobs if self.echo else None, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=(self.spaces_between_special_tokens), + ) + class LogProbs(BaseModel): text_offset: List[int] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9b843a94..83d70e02 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -11,7 +11,6 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams from vllm.entrypoints.openai.serving_engine import OpenAIServing logger = init_logger(__name__) @@ -60,32 +59,11 @@ class 
OpenAIServingChat(OpenAIServing): f"Error in applying chat template from request: {str(e)}") return self.create_error_response(str(e)) - token_ids, error_check_ret = await self._check_length(request, - prompt=prompt) - if error_check_ret is not None: - return error_check_ret - request_id = f"cmpl-{random_uuid()}" try: - spaces_between_special_tokens = request.spaces_between_special_tokens - sampling_params = SamplingParams( - n=request.n, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - repetition_penalty=request.repetition_penalty, - temperature=request.temperature, - top_p=request.top_p, - min_p=request.min_p, - stop=request.stop, - stop_token_ids=request.stop_token_ids, - max_tokens=request.max_tokens, - best_of=request.best_of, - top_k=request.top_k, - ignore_eos=request.ignore_eos, - use_beam_search=request.use_beam_search, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) + token_ids = self._validate_prompt_and_tokenize(request, + prompt=prompt) + sampling_params = request.to_sampling_params() except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index d842d1a2..d668ed50 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,20 +1,194 @@ import time from fastapi import Request -from typing import AsyncGenerator, Optional +from typing import AsyncGenerator, AsyncIterator from vllm.logger import init_logger from vllm.utils import random_uuid from vllm.engine.async_llm_engine import AsyncLLMEngine -from .protocol import (CompletionRequest, CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, LogProbs, UsageInfo) +from .protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + LogProbs, + UsageInfo, +) from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams from vllm.entrypoints.openai.serving_engine import OpenAIServing logger = init_logger(__name__) +async def completion_stream_generator( + request: CompletionRequest, + result_generator: AsyncIterator[RequestOutput], + echo_without_generation, create_logprobs_fn, request_id, created_time, + model_name) -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + has_echoed = [False] * request.n + + async for res in result_generator: + # TODO: handle client disconnect for streaming + for output in res.outputs: + i = output.index + delta_text = output.text[len(previous_texts[i]):] + token_ids = output.token_ids[previous_num_tokens[i]:] + if request.logprobs is not None: + top_logprobs = output.logprobs[previous_num_tokens[i]:] + else: + top_logprobs = None + offsets = len(previous_texts[i]) + if request.echo and not has_echoed[i]: + if not echo_without_generation: + delta_text = res.prompt + delta_text + token_ids = res.prompt_token_ids + token_ids + if top_logprobs: + top_logprobs = res.prompt_logprobs + top_logprobs + else: # only just return the prompt + delta_text = res.prompt + token_ids = res.prompt_token_ids + if top_logprobs: + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + if request.logprobs is not None: + logprobs = create_logprobs_fn( + token_ids=token_ids, + top_logprobs=top_logprobs, + 
num_output_top_logprobs=request.logprobs, + initial_text_offset=offsets, + ) + else: + logprobs = None + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + ]).json(exclude_unset=True, ensure_ascii=False) + yield f"data: {response_json}\n\n" + + if output.finish_reason is not None: + logprobs = LogProbs() if request.logprobs is not None else None + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + ], + usage=final_usage, + ).json(exclude_unset=True, ensure_ascii=False) + yield f"data: {response_json}\n\n" + + yield "data: [DONE]\n\n" + + +def parse_prompt_format(prompt) -> tuple[bool, list]: + # get the prompt, openai supports the following + # "a string, array of strings, array of tokens, or array of token arrays." + prompt_is_tokens = False + prompts = [prompt] # case 1: a string + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + elif isinstance(prompt[0], str): + prompt_is_tokens = False + prompts = prompt # case 2: array of strings + elif isinstance(prompt[0], int): + prompt_is_tokens = True + prompts = [prompt] # case 3: array of tokens + elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int): + prompt_is_tokens = True + prompts = prompt # case 4: array of token arrays + else: + raise ValueError( + "prompt must be a string, array of strings, array of tokens, or array of token arrays" + ) + return prompt_is_tokens, prompts + + +def request_output_to_completion_response(final_res: RequestOutput, request, + echo_without_generation, + create_logprobs_fn, request_id, + created_time, + model_name) -> CompletionResponse: + assert final_res is not None + choices = [] + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + for output in final_res.outputs: + if request.logprobs is not None: + if not echo_without_generation: + token_ids = output.token_ids + top_logprobs = output.logprobs + if request.echo: + token_ids = prompt_token_ids + token_ids + top_logprobs = prompt_logprobs + top_logprobs + else: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + logprobs = create_logprobs_fn( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + if not echo_without_generation: + output_text = output.text + if request.echo: + output_text = prompt_text + output_text + else: + output_text = prompt_text + choice_data = CompletionResponseChoice( + index=output.index, + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + 
prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model: str): @@ -32,7 +206,6 @@ class OpenAIServingCompletion(OpenAIServing): suffix) - logit_bias (to be supported by vLLM engine) """ - error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -40,83 +213,42 @@ class OpenAIServingCompletion(OpenAIServing): # OpenAI API supports echoing the prompt when max_tokens is 0. echo_without_generation = request.echo and request.max_tokens == 0 + # Return error for unsupported features. if request.suffix is not None: - # The language models we currently support do not support suffix. return self.create_error_response( "suffix is not currently supported") - if request.logit_bias is not None and len(request.logit_bias) > 0: - # TODO: support logit_bias in vLLM engine. return self.create_error_response( "logit_bias is not currently supported") model_name = request.model request_id = f"cmpl-{random_uuid()}" - - use_token_ids = False - if isinstance(request.prompt, list): - if len(request.prompt) == 0: - return self.create_error_response( - "please provide at least one prompt") - first_element = request.prompt[0] - if isinstance(first_element, int): - use_token_ids = True - prompt = request.prompt - elif isinstance(first_element, (str, list)): - # TODO: handles multiple prompt case in list[list[int]] - if len(request.prompt) > 1: - return self.create_error_response( - "multiple prompts in a batch is not currently supported" - ) - use_token_ids = not isinstance(first_element, str) - prompt = request.prompt[0] - else: - prompt = request.prompt - - if use_token_ids: - _, error_check_ret = await self._check_length(request, - prompt_ids=prompt) - else: - token_ids, error_check_ret = await self._check_length( - request, prompt=prompt) - if error_check_ret is not None: - return error_check_ret - created_time = int(time.monotonic()) - try: - spaces_between_special_tokens = request.spaces_between_special_tokens - sampling_params = SamplingParams( - n=request.n, - best_of=request.best_of, - presence_penalty=request.presence_penalty, - frequency_penalty=request.frequency_penalty, - repetition_penalty=request.repetition_penalty, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - min_p=request.min_p, - stop=request.stop, - stop_token_ids=request.stop_token_ids, - ignore_eos=request.ignore_eos, - max_tokens=request.max_tokens - if not echo_without_generation else 1, - logprobs=request.logprobs, - use_beam_search=request.use_beam_search, - prompt_logprobs=request.logprobs if request.echo else None, - skip_special_tokens=request.skip_special_tokens, - spaces_between_special_tokens=spaces_between_special_tokens, - ) - except ValueError as e: - return self.create_error_response(str(e)) - if use_token_ids: + # Schedule the request and get the result generator. 
+ try: + sampling_params = request.to_sampling_params() + + prompt_is_tokens, prompts = parse_prompt_format(request.prompt) + + if len(prompts) > 1: + raise ValueError( + "Batching in completion API is not supported.") + prompt = prompts[0] + + if prompt_is_tokens: + input_ids = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + input_ids = self._validate_prompt_and_tokenize(request, + prompt=prompt) + result_generator = self.engine.generate(None, sampling_params, request_id, - prompt_token_ids=prompt) - else: - result_generator = self.engine.generate(prompt, sampling_params, - request_id, token_ids) + prompt_token_ids=input_ids) + except ValueError as e: + return self.create_error_response(str(e)) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use beam search. @@ -124,101 +256,13 @@ class OpenAIServingCompletion(OpenAIServing): and (request.best_of is None or request.n == request.best_of) and not request.use_beam_search) - def create_stream_response_json( - index: int, - text: str, - logprobs: Optional[LogProbs] = None, - finish_reason: Optional[str] = None, - usage: Optional[UsageInfo] = None, - ) -> str: - choice_data = CompletionResponseStreamChoice( - index=index, - text=text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - response = CompletionStreamResponse( - id=request_id, - created=created_time, - model=model_name, - choices=[choice_data], - ) - if usage is not None: - response.usage = usage - response_json = response.json(exclude_unset=True, - ensure_ascii=False) - - return response_json - - async def completion_stream_generator() -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - has_echoed = [False] * request.n - async for res in result_generator: - res: RequestOutput - for output in res.outputs: - i = output.index - delta_text = output.text[len(previous_texts[i]):] - token_ids = output.token_ids[previous_num_tokens[i]:] - if request.logprobs is not None: - top_logprobs = output.logprobs[previous_num_tokens[i]:] - else: - top_logprobs = None - offsets = len(previous_texts[i]) - if request.echo and not has_echoed[i]: - if not echo_without_generation: - delta_text = res.prompt + delta_text - token_ids = res.prompt_token_ids + token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs + top_logprobs - else: # only just return the prompt - delta_text = res.prompt - token_ids = res.prompt_token_ids - if top_logprobs: - top_logprobs = res.prompt_logprobs - has_echoed[i] = True - if request.logprobs is not None: - logprobs = self._create_logprobs( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - initial_text_offset=offsets, - ) - else: - logprobs = None - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) - finish_reason = output.finish_reason - response_json = create_stream_response_json( - index=i, - text=delta_text, - logprobs=logprobs, - finish_reason=finish_reason, - ) - yield f"data: {response_json}\n\n" - if output.finish_reason is not None: - logprobs = (LogProbs() - if request.logprobs is not None else None) - prompt_tokens = len(res.prompt_token_ids) - completion_tokens = len(output.token_ids) - final_usage = UsageInfo( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - response_json = create_stream_response_json( - index=i, - text="", - 
logprobs=logprobs, - finish_reason=output.finish_reason, - usage=final_usage, - ) - yield f"data: {response_json}\n\n" - yield "data: [DONE]\n\n" - # Streaming response if stream: - return completion_stream_generator() + return completion_stream_generator(request, result_generator, + echo_without_generation, + self._create_logprobs, + request_id, created_time, + model_name) # Non-streaming response final_res: RequestOutput = None @@ -228,62 +272,13 @@ class OpenAIServingCompletion(OpenAIServing): await self.engine.abort(request_id) return self.create_error_response("Client disconnected") final_res = res - assert final_res is not None - choices = [] - prompt_token_ids = final_res.prompt_token_ids - prompt_logprobs = final_res.prompt_logprobs - prompt_text = final_res.prompt - for output in final_res.outputs: - if request.logprobs is not None: - if not echo_without_generation: - token_ids = output.token_ids - top_logprobs = output.logprobs - if request.echo: - token_ids = prompt_token_ids + token_ids - top_logprobs = prompt_logprobs + top_logprobs - else: - token_ids = prompt_token_ids - top_logprobs = prompt_logprobs - logprobs = self._create_logprobs( - token_ids=token_ids, - top_logprobs=top_logprobs, - num_output_top_logprobs=request.logprobs, - ) - else: - logprobs = None - if not echo_without_generation: - output_text = output.text - if request.echo: - output_text = prompt_text + output_text - else: - output_text = prompt_text - choice_data = CompletionResponseChoice( - index=output.index, - text=output_text, - logprobs=logprobs, - finish_reason=output.finish_reason, - ) - choices.append(choice_data) - - num_prompt_tokens = len(final_res.prompt_token_ids) - num_generated_tokens = sum( - len(output.token_ids) for output in final_res.outputs) - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - completion_tokens=num_generated_tokens, - total_tokens=num_prompt_tokens + num_generated_tokens, - ) - response = CompletionResponse( - id=request_id, - created=created_time, - model=model_name, - choices=choices, - usage=usage, - ) + response = request_output_to_completion_response( + final_res, request, echo_without_generation, self._create_logprobs, + request_id, created_time, model_name) + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. if request.stream: - # When user requests streaming but we don't stream, we still need to - # return a streaming response with a single event. 
response_json = response.json(ensure_ascii=False) async def fake_stream_generator() -> AsyncGenerator[str, None]: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index e77a0720..390f9aeb 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,6 +1,6 @@ import asyncio from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -104,27 +104,30 @@ class OpenAIServing: err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - async def _check_length( - self, - request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None - ) -> Tuple[List[int], Optional[ErrorResponse]]: - assert (not (prompt is None and prompt_ids is None) - and not (prompt is not None and prompt_ids is not None) - ), "Either prompt or prompt_ids should be provided." + def _validate_prompt_and_tokenize( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None) -> List[int]: + if not (prompt or prompt_ids): + raise ValueError("Either prompt or prompt_ids should be provided.") + if (prompt and prompt_ids): + raise ValueError( + "Only one of prompt or prompt_ids should be provided.") + input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( prompt).input_ids token_num = len(input_ids) if request.max_tokens is None: request.max_tokens = self.max_model_len - token_num + if token_num + request.max_tokens > self.max_model_len: - return input_ids, self.create_error_response( + raise ValueError( f"This model's maximum context length is {self.max_model_len} tokens. " f"However, you requested {request.max_tokens + token_num} tokens " f"({token_num} in the messages, " f"{request.max_tokens} in the completion). " f"Please reduce the length of the messages or completion.", ) else: - return input_ids, None + return input_ids diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 3b3958d6..b725122c 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -163,7 +163,7 @@ def prepare_hf_model_weights( use_safetensors = True break - logger.info(f"Downloading model weights {allow_patterns}") + logger.info(f"Using model weights format {allow_patterns}") # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir):
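
For illustration only (not part of the patch): a minimal client-side sketch of the two single-prompt formats the refactored /v1/completions endpoint accepts, mirroring the string and token-ID cases handled by parse_prompt_format(). Array-of-strings and array-of-token-arrays inputs are parsed as well, but the endpoint still rejects more than one prompt per request ("Batching in completion API is not supported."). The base URL, API key, and model name below are placeholders for a locally running vLLM OpenAI-compatible server.

import asyncio

import openai

# Placeholders: point these at a running vLLM OpenAI-compatible server,
# e.g. `python -m vllm.entrypoints.openai.api_server --model <model>`.
client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")
MODEL_NAME = "my-model"  # placeholder; use the served model name


async def main() -> None:
    # Case 1 in parse_prompt_format(): a plain string prompt.
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        temperature=0.0,
    )
    print(completion.choices[0].text)

    # Case 3 in parse_prompt_format(): a single array of token IDs.
    # The server skips tokenization and forwards the IDs to the engine
    # via prompt_token_ids, as exercised by the new unit test above.
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    print(completion.choices[0].text)


asyncio.run(main())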