[Frontend] Support complex message content for chat completions endpoint (#3467)
Co-authored-by: Lily Liu <lilyliupku@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

parent 111815d482
commit a494140433
@@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
     assert "extra_forbidden" in exc_info.value.message


+async def test_complex_message_content(server, client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role":
+            "user",
+            "content": [{
+                "type":
+                "text",
+                "text":
+                "what is 1+1? please provide the result without any other text."
+            }]
+        }],
+        temperature=0,
+        seed=0)
+    content = resp.choices[0].message.content
+    assert content == "2"
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement
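For context, the new test exercises the OpenAI-style "content parts" format, where `content` is a list of typed objects rather than a plain string. The same request can also be sent without the OpenAI client; below is a minimal sketch assuming a vLLM OpenAI-compatible server at http://localhost:8000 and a placeholder model name.

```python
import json
from urllib.request import Request, urlopen

# Placeholder model name for illustration; substitute the served model.
payload = {
    "model": "my-model",
    "messages": [{
        "role": "user",
        # "content" is a list of typed parts, not a plain string.
        "content": [{"type": "text", "text": "what is 1+1?"}],
    }],
    "temperature": 0,
}
req = Request("http://localhost:8000/v1/chat/completions",
              data=json.dumps(payload).encode("utf-8"),
              headers={"Content-Type": "application/json"})
with urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["message"]["content"])
```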
@@ -55,9 +55,16 @@ class OpenAIServingChat(OpenAIServing):
         if isinstance(content, str):
             return [ConversationMessage(role=role, content=content)], []

-        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
-        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
-        raise NotImplementedError("Complex input not supported yet")
+        texts: List[str] = []
+        for _, part in enumerate(content):
+            if part["type"] == "text":
+                text = part["text"]
+
+                texts.append(text)
+            else:
+                raise NotImplementedError(f"Unknown part type: {part['type']}")
+
+        return [ConversationMessage(role=role, content="\n".join(texts))], []

     async def create_chat_completion(
             self, request: ChatCompletionRequest, raw_request: Request
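The new parsing path replaces the old `NotImplementedError` stub: it flattens a list of `text` parts into a single newline-joined string before building the `ConversationMessage`. A standalone sketch of that logic, runnable outside the server class (the `flatten_content` name is illustrative, not part of vLLM's API):

```python
from typing import List, Union


def flatten_content(content: Union[str, List[dict]]) -> str:
    """Illustrative restatement of the flattening logic added above."""
    if isinstance(content, str):
        return content
    texts: List[str] = []
    for part in content:
        if part["type"] == "text":
            texts.append(part["text"])
        else:
            # Non-text parts (e.g. images) are deferred to a later PR.
            raise NotImplementedError(f"Unknown part type: {part['type']}")
    return "\n".join(texts)


assert flatten_content("hi") == "hi"
assert flatten_content([{"type": "text", "text": "a"},
                        {"type": "text", "text": "b"}]) == "a\nb"
```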
@@ -122,11 +129,12 @@ class OpenAIServingChat(OpenAIServing):
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, conversation)
         else:
             try:
                 return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id)
+                    request, raw_request, result_generator, request_id,
+                    conversation)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -139,8 +147,9 @@ class OpenAIServingChat(OpenAIServing):

     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> AsyncGenerator[str, None]:
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"
@@ -179,12 +188,10 @@ class OpenAIServingChat(OpenAIServing):
                 # last message
                 if request.echo:
                     last_msg_content = ""
-                    if request.messages and isinstance(
-                            request.messages,
-                            list) and request.messages[-1].get(
-                                "content") and request.messages[-1].get(
-                                    "role") == role:
-                        last_msg_content = request.messages[-1]["content"]
+                    if conversation and conversation[-1].get(
+                            "content") and conversation[-1].get(
+                                "role") == role:
+                        last_msg_content = conversation[-1]["content"]

                     if last_msg_content:
                         for i in range(request.n):
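The `echo` branch now reads from the parsed `conversation` rather than the raw `request.messages`: with complex content, `request.messages[-1]["content"]` may be a list of parts, while `conversation[-1]["content"]` is always the flattened string that can safely be prepended to the generated text. A small sketch of the difference, with plain dicts standing in for the real request and conversation objects:

```python
# Raw request: content is still a list of typed parts.
request_messages = [{
    "role": "user",
    "content": [{"type": "text", "text": "what is 1+1?"}],
}]
# Parsed conversation: content has been flattened to a string.
conversation = [{"role": "user", "content": "what is 1+1?"}]

role = "user"
last_msg_content = ""
if conversation and conversation[-1].get(
        "content") and conversation[-1].get("role") == role:
    last_msg_content = conversation[-1]["content"]

assert last_msg_content == "what is 1+1?"
# The old request.messages-based lookup would have produced a list,
# which cannot be concatenated with the generated completion string.
assert isinstance(request_messages[-1]["content"], list)
```

The same substitution is applied in `chat_completion_full_generator` below.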
@@ -279,9 +286,10 @@ class OpenAIServingChat(OpenAIServing):
         yield "data: [DONE]\n\n"

     async def chat_completion_full_generator(
             self, request: ChatCompletionRequest, raw_request: Request,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:

         model_name = self.served_model_names[0]
         created_time = int(time.time())
@@ -322,11 +330,9 @@ class OpenAIServingChat(OpenAIServing):

         if request.echo:
             last_msg_content = ""
-            if request.messages and isinstance(
-                    request.messages, list) and request.messages[-1].get(
-                        "content") and request.messages[-1].get(
-                            "role") == role:
-                last_msg_content = request.messages[-1]["content"]
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]

             for choice in choices:
                 full_message = last_msg_content + choice.message.content