[Frontend] Support complex message content for chat completions endpoint (#3467)

Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Florian Greinacher 2024-05-01 01:28:46 +02:00 committed by GitHub
parent 111815d482
commit a494140433
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 46 additions and 21 deletions

View File

@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
assert "extra_forbidden" in exc_info.value.message
async def test_complex_message_content(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": [{
"type":
"text",
"text":
"what is 1+1? please provide the result without any other text."
}]
}],
temperature=0,
seed=0)
content = resp.choices[0].message.content
assert content == "2"
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement

View File

@ -55,9 +55,16 @@ class OpenAIServingChat(OpenAIServing):
if isinstance(content, str):
return [ConversationMessage(role=role, content=content)], []
# To be implemented: https://github.com/vllm-project/vllm/pull/3467
# To be implemented: https://github.com/vllm-project/vllm/pull/4200
raise NotImplementedError("Complex input not supported yet")
texts: List[str] = []
for _, part in enumerate(content):
if part["type"] == "text":
text = part["text"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part['type']}")
return [ConversationMessage(role=role, content="\n".join(texts))], []
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
@ -122,11 +129,12 @@ class OpenAIServingChat(OpenAIServing):
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
request, result_generator, request_id, conversation)
else:
try:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
request, raw_request, result_generator, request_id,
conversation)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -139,8 +147,9 @@ class OpenAIServingChat(OpenAIServing):
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> AsyncGenerator[str, None]:
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> AsyncGenerator[str, None]:
model_name = self.served_model_names[0]
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
@ -179,12 +188,10 @@ class OpenAIServingChat(OpenAIServing):
# last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages,
list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
last_msg_content = conversation[-1]["content"]
if last_msg_content:
for i in range(request.n):
@ -279,9 +286,10 @@ class OpenAIServingChat(OpenAIServing):
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput], request_id: str,
conversation: List[ConversationMessage]
) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = self.served_model_names[0]
created_time = int(time.time())
@ -322,11 +330,9 @@ class OpenAIServingChat(OpenAIServing):
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content