[Frontend] Support complex message content for chat completions endpoint (#3467)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
parent 111815d482
commit a494140433
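With this change, the `content` field of a chat message may be either a plain string or an OpenAI-style list of typed parts (only "text" parts are handled so far; other part types raise NotImplementedError). A minimal client sketch of the newly accepted shape, mirroring the test below — the base URL, API key, and model name are placeholders, not values from this commit:

import openai

# Hypothetical local vLLM OpenAI-compatible server, for illustration only.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="my-model",  # placeholder model name
    messages=[{
        "role": "user",
        # "content" may now be a list of typed parts, not just a plain string:
        "content": [{"type": "text", "text": "what is 1+1?"}],
    }],
)
print(resp.choices[0].message.content)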
@@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
     assert "extra_forbidden" in exc_info.value.message
 
 
+async def test_complex_message_content(server, client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role":
+            "user",
+            "content": [{
+                "type":
+                "text",
+                "text":
+                "what is 1+1? please provide the result without any other text."
+            }]
+        }],
+        temperature=0,
+        seed=0)
+    content = resp.choices[0].message.content
+    assert content == "2"
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement
@@ -55,9 +55,16 @@ class OpenAIServingChat(OpenAIServing):
         if isinstance(content, str):
             return [ConversationMessage(role=role, content=content)], []
 
-        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
-        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
-        raise NotImplementedError("Complex input not supported yet")
+        texts: List[str] = []
+        for _, part in enumerate(content):
+            if part["type"] == "text":
+                text = part["text"]
+
+                texts.append(text)
+            else:
+                raise NotImplementedError(f"Unknown part type: {part['type']}")
+
+        return [ConversationMessage(role=role, content="\n".join(texts))], []
 
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
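For illustration, the parsing added above is behaviorally equivalent to this standalone sketch (`flatten_content` is a hypothetical name; the real method wraps the result in a ConversationMessage list and returns a second, empty list alongside it):

from typing import List, Union

def flatten_content(content: Union[str, List[dict]]) -> str:
    # Plain strings pass through unchanged.
    if isinstance(content, str):
        return content
    # "text" parts are collected and joined with newlines; any other
    # part type (e.g. images) is rejected for now.
    texts: List[str] = []
    for part in content:
        if part["type"] == "text":
            texts.append(part["text"])
        else:
            raise NotImplementedError(f"Unknown part type: {part['type']}")
    return "\n".join(texts)

# Two text parts collapse into one newline-joined string:
assert flatten_content([{"type": "text", "text": "a"},
                        {"type": "text", "text": "b"}]) == "a\nb"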
@@ -122,11 +129,12 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, conversation)
         else:
             try:
                 return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id)
+                    request, raw_request, result_generator, request_id,
+                    conversation)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -139,8 +147,9 @@ class OpenAIServingChat(OpenAIServing):
 
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> AsyncGenerator[str, None]:
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
@@ -179,12 +188,10 @@ class OpenAIServingChat(OpenAIServing):
                     # last message
                     if request.echo:
                         last_msg_content = ""
-                        if request.messages and isinstance(
-                                request.messages,
-                                list) and request.messages[-1].get(
-                                    "content") and request.messages[-1].get(
-                                        "role") == role:
-                            last_msg_content = request.messages[-1]["content"]
+                        if conversation and conversation[-1].get(
+                                "content") and conversation[-1].get(
+                                    "role") == role:
+                            last_msg_content = conversation[-1]["content"]
 
                         if last_msg_content:
                             for i in range(request.n):
@@ -279,9 +286,10 @@ class OpenAIServingChat(OpenAIServing):
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Request,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+            self, request: ChatCompletionRequest, raw_request: Request,
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.served_model_names[0]
         created_time = int(time.time())
@@ -322,11 +330,9 @@ class OpenAIServingChat(OpenAIServing):
 
         if request.echo:
             last_msg_content = ""
-            if request.messages and isinstance(
-                    request.messages, list) and request.messages[-1].get(
-                        "content") and request.messages[-1].get(
-                            "role") == role:
-                last_msg_content = request.messages[-1]["content"]
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]
 
         for choice in choices:
             full_message = last_msg_content + choice.message.content
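A note on the two echo changes above: request.messages[-1]["content"] may now be a list of parts rather than a string, so echoing from the raw request would no longer concatenate with the generated text; the parsed conversation always holds the flattened string. A toy illustration (the values are made up):

raw_content = [{"type": "text", "text": "what is 1+1?"}]  # raw request form
parsed_content = "what is 1+1?"  # after the "\n".join(...) flattening above

generated = "2"
# raw_content + generated would raise TypeError (list + str);
# the parsed string concatenates cleanly, as in full_message above:
full_message = parsed_content + generated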