[Frontend] Re-enable custom roles in Chat Completions API (#4758)
This commit is contained in:
parent
361c461a12
commit
fc0d9dfc3a
@ -783,6 +783,36 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
|
|||||||
assert content == "2"
|
assert content == "2"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_custom_role(server, client: openai.AsyncOpenAI):
|
||||||
|
# Not sure how the model handles custom roles so we just check that
|
||||||
|
# both string and complex message content are handled in the same way
|
||||||
|
|
||||||
|
resp1 = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{
|
||||||
|
"role": "my-custom-role",
|
||||||
|
"content": "what is 1+1?",
|
||||||
|
}], # type: ignore
|
||||||
|
temperature=0,
|
||||||
|
seed=0)
|
||||||
|
|
||||||
|
resp2 = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{
|
||||||
|
"role": "my-custom-role",
|
||||||
|
"content": [{
|
||||||
|
"type": "text",
|
||||||
|
"text": "what is 1+1?"
|
||||||
|
}]
|
||||||
|
}], # type: ignore
|
||||||
|
temperature=0,
|
||||||
|
seed=0)
|
||||||
|
|
||||||
|
content1 = resp1.choices[0].message.content
|
||||||
|
content2 = resp2.choices[0].message.content
|
||||||
|
assert content1 == content2
|
||||||
|
|
||||||
|
|
||||||
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
|
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
|
||||||
simple_sql_grammar = """
|
simple_sql_grammar = """
|
||||||
start: select_statement
|
start: select_statement
|
||||||
|
|||||||
@ -3,16 +3,50 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Literal, Optional, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
import openai.types.chat
|
||||||
import torch
|
import torch
|
||||||
from openai.types.chat import ChatCompletionMessageParam
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
from typing_extensions import Annotated
|
# pydantic needs the TypedDict from typing_extensions
|
||||||
|
from typing_extensions import Annotated, Required, TypedDict
|
||||||
|
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
|
|
||||||
|
|
||||||
|
class CustomChatCompletionContentPartParam(TypedDict, total=False):
|
||||||
|
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore
|
||||||
|
|
||||||
|
type: Required[str]
|
||||||
|
"""The type of the content part."""
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionContentPartParam = Union[
|
||||||
|
openai.types.chat.ChatCompletionContentPartParam,
|
||||||
|
CustomChatCompletionContentPartParam]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||||
|
"""Enables custom roles in the Chat Completion API."""
|
||||||
|
role: Required[str]
|
||||||
|
"""The role of the message's author."""
|
||||||
|
|
||||||
|
content: Union[str, List[ChatCompletionContentPartParam]]
|
||||||
|
"""The contents of the message."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
"""An optional name for the participant.
|
||||||
|
|
||||||
|
Provides the model information to differentiate between participants of the
|
||||||
|
same role.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionMessageParam = Union[
|
||||||
|
openai.types.chat.ChatCompletionMessageParam,
|
||||||
|
CustomChatCompletionMessageParam]
|
||||||
|
|
||||||
|
|
||||||
class OpenAIBaseModel(BaseModel):
|
class OpenAIBaseModel(BaseModel):
|
||||||
# OpenAI API does not allow extra fields
|
# OpenAI API does not allow extra fields
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|||||||
@ -1,15 +1,16 @@
|
|||||||
import codecs
|
import codecs
|
||||||
import time
|
import time
|
||||||
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
|
from dataclasses import dataclass
|
||||||
Optional, Tuple, TypedDict, Union, final)
|
from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional,
|
||||||
|
TypedDict, Union, cast, final)
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
from openai.types.chat import (ChatCompletionContentPartParam,
|
from openai.types.chat import ChatCompletionContentPartTextParam
|
||||||
ChatCompletionRole)
|
|
||||||
|
|
||||||
from vllm.config import ModelConfig
|
from vllm.config import ModelConfig
|
||||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
|
ChatCompletionContentPartParam, ChatCompletionMessageParam,
|
||||||
ChatCompletionRequest, ChatCompletionResponse,
|
ChatCompletionRequest, ChatCompletionResponse,
|
||||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||||
@ -31,6 +32,11 @@ class ConversationMessage(TypedDict):
|
|||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ChatMessageParseResult:
|
||||||
|
messages: List[ConversationMessage]
|
||||||
|
|
||||||
|
|
||||||
class OpenAIServingChat(OpenAIServing):
|
class OpenAIServingChat(OpenAIServing):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -77,27 +83,40 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"No chat template provided. Chat API will not work.")
|
"No chat template provided. Chat API will not work.")
|
||||||
|
|
||||||
def _parse_chat_message_content(
|
def _parse_chat_message_content_parts(
|
||||||
self,
|
self,
|
||||||
role: ChatCompletionRole,
|
role: str,
|
||||||
content: Optional[Union[str,
|
parts: Iterable[ChatCompletionContentPartParam],
|
||||||
Iterable[ChatCompletionContentPartParam]]],
|
) -> ChatMessageParseResult:
|
||||||
) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
|
|
||||||
if content is None:
|
|
||||||
return [], []
|
|
||||||
if isinstance(content, str):
|
|
||||||
return [ConversationMessage(role=role, content=content)], []
|
|
||||||
|
|
||||||
texts: List[str] = []
|
texts: List[str] = []
|
||||||
for _, part in enumerate(content):
|
|
||||||
if part["type"] == "text":
|
for _, part in enumerate(parts):
|
||||||
text = part["text"]
|
part_type = part["type"]
|
||||||
|
if part_type == "text":
|
||||||
|
text = cast(ChatCompletionContentPartTextParam, part)["text"]
|
||||||
|
|
||||||
texts.append(text)
|
texts.append(text)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unknown part type: {part['type']}")
|
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||||
|
|
||||||
return [ConversationMessage(role=role, content="\n".join(texts))], []
|
messages = [ConversationMessage(role=role, content="\n".join(texts))]
|
||||||
|
|
||||||
|
return ChatMessageParseResult(messages=messages)
|
||||||
|
|
||||||
|
def _parse_chat_message_content(
|
||||||
|
self,
|
||||||
|
message: ChatCompletionMessageParam,
|
||||||
|
) -> ChatMessageParseResult:
|
||||||
|
role = message["role"]
|
||||||
|
content = message.get("content")
|
||||||
|
|
||||||
|
if content is None:
|
||||||
|
return ChatMessageParseResult(messages=[])
|
||||||
|
if isinstance(content, str):
|
||||||
|
messages = [ConversationMessage(role=role, content=content)]
|
||||||
|
return ChatMessageParseResult(messages=messages)
|
||||||
|
|
||||||
|
return self._parse_chat_message_content_parts(role, content)
|
||||||
|
|
||||||
async def create_chat_completion(
|
async def create_chat_completion(
|
||||||
self, request: ChatCompletionRequest, raw_request: Request
|
self, request: ChatCompletionRequest, raw_request: Request
|
||||||
@ -119,11 +138,10 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
try:
|
try:
|
||||||
conversation: List[ConversationMessage] = []
|
conversation: List[ConversationMessage] = []
|
||||||
|
|
||||||
for m in request.messages:
|
for msg in request.messages:
|
||||||
messages, _ = self._parse_chat_message_content(
|
parsed_msg = self._parse_chat_message_content(msg)
|
||||||
m["role"], m["content"])
|
|
||||||
|
|
||||||
conversation.extend(messages)
|
conversation.extend(parsed_msg.messages)
|
||||||
|
|
||||||
prompt = self.tokenizer.apply_chat_template(
|
prompt = self.tokenizer.apply_chat_template(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user