diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7e9f53b1..40d27f98 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class RequestResponseMetadata(BaseModel):
+    request_id: str
+    final_usage_info: Optional[UsageInfo] = None
+
+
 class JsonSchemaResponseFormat(OpenAIBaseModel):
     name: str
     description: Optional[str] = None
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 1ee4b3ce..0321ea98 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
+    ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
@@ -175,6 +176,11 @@ class OpenAIServingChat(OpenAIServing):
                 "--enable-auto-tool-choice and --tool-call-parser to be set")
 
         request_id = f"chat-{random_uuid()}"
+
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
         try:
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))
@@ -241,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation, tokenizer)
+                request, result_generator, request_id, conversation, tokenizer,
+                request_metadata)
 
         try:
             return await self.chat_completion_full_generator(
-                request, result_generator, request_id, conversation, tokenizer)
+                request, result_generator, request_id, conversation, tokenizer,
+                request_metadata)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -262,6 +270,7 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         model_name = self.base_model_paths[0].name
         created_time = int(time.time())
@@ -580,6 +589,13 @@ class OpenAIServingChat(OpenAIServing):
                     exclude_unset=True, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
+            # report to FastAPI middleware aggregate usage across all choices
+            num_completion_tokens = sum(previous_num_tokens)
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=num_completion_tokens,
+                total_tokens=num_prompt_tokens + num_completion_tokens)
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             logger.error("error in chat completion stream generator: %s", e)
@@ -595,6 +611,7 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
    ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.base_model_paths[0].name
@@ -714,6 +731,9 @@ class OpenAIServingChat(OpenAIServing):
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
+
+        request_metadata.final_usage_info = usage
+
         response = ChatCompletionResponse(
             id=request_id,
             created=created_time,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 9abd74d0..0e860900 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -18,7 +18,9 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              ErrorResponse, UsageInfo)
+                                              ErrorResponse,
+                                              RequestResponseMetadata,
+                                              UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
@@ -94,6 +96,10 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -165,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
 
         # Streaming response
         if stream:
-            return self.completion_stream_generator(request,
-                                                    result_generator,
-                                                    request_id,
-                                                    created_time,
-                                                    model_name,
-                                                    num_prompts=len(prompts),
-                                                    tokenizer=tokenizer)
+            return self.completion_stream_generator(
+                request,
+                result_generator,
+                request_id,
+                created_time,
+                model_name,
+                num_prompts=len(prompts),
+                tokenizer=tokenizer,
+                request_metadata=request_metadata)
 
         # Non-streaming response
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -198,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 created_time,
                 model_name,
                 tokenizer,
+                request_metadata,
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -227,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name: str,
         num_prompts: int,
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         num_choices = 1 if request.n is None else request.n
         previous_text_lens = [0] * num_choices * num_prompts
@@ -346,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
+            # report to FastAPI middleware aggregate usage across all choices
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             data = self.create_streaming_error_response(str(e))
@@ -360,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> CompletionResponse:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
@@ -433,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+        request_metadata.final_usage_info = usage
+
         return CompletionResponse(
             id=request_id,
             created=created_time,
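For reference, the RequestResponseMetadata stored on raw_request.state is meant to be read back from FastAPI middleware once a response has completed. The sketch below is not part of this diff: the middleware name log_final_usage and the print-based logging are illustrative assumptions, and how such middleware gets attached to vLLM's API server is outside the scope of the change. Note that for streaming requests final_usage_info is only populated after the SSE generator finishes, so a middleware that needs streaming usage must defer the read until the response body has been fully sent.

from fastapi import FastAPI, Request

app = FastAPI()


@app.middleware("http")
async def log_final_usage(request: Request, call_next):
    """Illustrative sketch: read the RequestResponseMetadata that the
    serving handlers in this diff attach to request.state."""
    response = await call_next(request)

    # The chat/completion handlers set this attribute at the start of each
    # request; final_usage_info is filled in once generation finishes.
    metadata = getattr(request.state, "request_metadata", None)
    if metadata is not None and metadata.final_usage_info is not None:
        usage = metadata.final_usage_info
        print(f"{metadata.request_id}: prompt={usage.prompt_tokens} "
              f"completion={usage.completion_tokens} "
              f"total={usage.total_tokens}")
    return response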