[Frontend] OpenAI server: propagate usage accounting to FastAPI middleware layer (#8672)
commit 1ac3de09cd
parent 3e073e66f1
vllm/entrypoints/openai/protocol.py

@@ -107,6 +107,11 @@ class UsageInfo(OpenAIBaseModel):
     completion_tokens: Optional[int] = 0
 
 
+class RequestResponseMetadata(BaseModel):
+    request_id: str
+    final_usage_info: Optional[UsageInfo] = None
+
+
 class JsonSchemaResponseFormat(OpenAIBaseModel):
     name: str
     description: Optional[str] = None
vllm/entrypoints/openai/serving_chat.py

@@ -22,7 +22,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
-    DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo)
+    DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata,
+    ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
                                                     OpenAIServing,
@@ -175,6 +176,11 @@ class OpenAIServingChat(OpenAIServing):
                 "--enable-auto-tool-choice and --tool-call-parser to be set")
 
         request_id = f"chat-{random_uuid()}"
+
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
         try:
             guided_decode_logits_processor = (
                 await self._guided_decode_logits_processor(request, tokenizer))
@@ -241,11 +247,13 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation, tokenizer)
+                request, result_generator, request_id, conversation, tokenizer,
+                request_metadata)
 
         try:
             return await self.chat_completion_full_generator(
-                request, result_generator, request_id, conversation, tokenizer)
+                request, result_generator, request_id, conversation, tokenizer,
+                request_metadata)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -262,6 +270,7 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         model_name = self.base_model_paths[0].name
         created_time = int(time.time())
@@ -580,6 +589,13 @@ class OpenAIServingChat(OpenAIServing):
                     exclude_unset=True, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
+            # report to FastAPI middleware aggregate usage across all choices
+            num_completion_tokens = sum(previous_num_tokens)
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=num_completion_tokens,
+                total_tokens=num_prompt_tokens + num_completion_tokens)
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             logger.error("error in chat completion stream generator: %s", e)
@@ -595,6 +611,7 @@ class OpenAIServingChat(OpenAIServing):
         request_id: str,
         conversation: List[ConversationMessage],
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.base_model_paths[0].name
@@ -714,6 +731,9 @@ class OpenAIServingChat(OpenAIServing):
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
+
+        request_metadata.final_usage_info = usage
+
         response = ChatCompletionResponse(
             id=request_id,
             created=created_time,
vllm/entrypoints/openai/serving_completion.py

@@ -18,7 +18,9 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              ErrorResponse, UsageInfo)
+                                              ErrorResponse,
+                                              RequestResponseMetadata,
+                                              UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     LoRAModulePath,
@@ -94,6 +96,10 @@ class OpenAIServingCompletion(OpenAIServing):
         request_id = f"cmpl-{random_uuid()}"
         created_time = int(time.time())
 
+        request_metadata = RequestResponseMetadata(request_id=request_id)
+        if raw_request:
+            raw_request.state.request_metadata = request_metadata
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[RequestOutput, None]] = []
         try:
@@ -165,13 +171,15 @@ class OpenAIServingCompletion(OpenAIServing):
 
         # Streaming response
         if stream:
-            return self.completion_stream_generator(request,
-                                                    result_generator,
-                                                    request_id,
-                                                    created_time,
-                                                    model_name,
-                                                    num_prompts=len(prompts),
-                                                    tokenizer=tokenizer)
+            return self.completion_stream_generator(
+                request,
+                result_generator,
+                request_id,
+                created_time,
+                model_name,
+                num_prompts=len(prompts),
+                tokenizer=tokenizer,
+                request_metadata=request_metadata)
 
         # Non-streaming response
         final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
@@ -198,6 +206,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 created_time,
                 model_name,
                 tokenizer,
+                request_metadata,
             )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
@@ -227,6 +236,7 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name: str,
         num_prompts: int,
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> AsyncGenerator[str, None]:
         num_choices = 1 if request.n is None else request.n
         previous_text_lens = [0] * num_choices * num_prompts
@@ -346,6 +356,14 @@ class OpenAIServingCompletion(OpenAIServing):
                     exclude_unset=False, exclude_none=True))
                 yield f"data: {final_usage_data}\n\n"
 
+            # report to FastAPI middleware aggregate usage across all choices
+            total_prompt_tokens = sum(num_prompt_tokens)
+            total_completion_tokens = sum(previous_num_tokens)
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens)
+
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             data = self.create_streaming_error_response(str(e))
@@ -360,6 +378,7 @@ class OpenAIServingCompletion(OpenAIServing):
         created_time: int,
         model_name: str,
         tokenizer: AnyTokenizer,
+        request_metadata: RequestResponseMetadata,
     ) -> CompletionResponse:
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
@@ -433,6 +452,8 @@ class OpenAIServingCompletion(OpenAIServing):
             total_tokens=num_prompt_tokens + num_generated_tokens,
         )
 
+        request_metadata.final_usage_info = usage
+
         return CompletionResponse(
             id=request_id,
             created=created_time,
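The diff above only produces the metadata: the serving layer stores a RequestResponseMetadata object on request.state and fills final_usage_info once generation finishes. Consuming it is left to user code. Below is a minimal sketch, not part of this commit, of how a FastAPI HTTP middleware could read the propagated usage; the names app, track_token_usage, and _log_usage are hypothetical, and it assumes Starlette runs a response's background task only after the body (including an SSE stream) has been sent, so the usage is already recorded for streaming requests as well.

# Hypothetical middleware sketch: reads the RequestResponseMetadata that
# OpenAIServingChat / OpenAIServingCompletion now attach to request.state
# and logs the aggregated token usage.
from fastapi import FastAPI, Request
from starlette.background import BackgroundTask

app = FastAPI()  # assumption: stand-in for the server's FastAPI app


def _log_usage(request: Request) -> None:
    # Set by the serving layer in this commit; final_usage_info is populated
    # once generation (streaming or non-streaming) has completed.
    metadata = getattr(request.state, "request_metadata", None)
    if metadata is not None and metadata.final_usage_info is not None:
        usage = metadata.final_usage_info
        print(f"{metadata.request_id}: prompt={usage.prompt_tokens} "
              f"completion={usage.completion_tokens} "
              f"total={usage.total_tokens}")


@app.middleware("http")
async def track_token_usage(request: Request, call_next):
    response = await call_next(request)
    # Defer the read to a background task so it runs after the response body
    # has been fully sent, by which point final_usage_info has been set.
    response.background = BackgroundTask(_log_usage, request)
    return response

In a real deployment the print call would typically be replaced by a metrics or billing sink; the point of the sketch is only where the usage becomes available to the middleware layer.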