[Frontend] Gracefully handle missing chat template and fix CI failure (#7238)

Co-authored-by: Roger Wang <ywang@roblox.com>
Cyrus Leung 2024-08-07 17:12:05 +08:00 committed by GitHub
parent 7b261092de
commit 66d617e343
9 changed files with 125 additions and 69 deletions

View File

@@ -1,22 +1,16 @@
-import os
-import pathlib
-
 import pytest
 
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
-    __file__))).parent.parent / "examples/template_chatml.jinja"
+from ..utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", None, True,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
-    ("facebook/opt-125m", None, False,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
     ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)
 
     # Call the function and get the result
-    result = tokenizer.apply_chat_template(
+    result = apply_chat_template(
+        tokenizer,
         conversation=mock_request.messages,
-        tokenize=False,
+        chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
-        chat_template=mock_request.chat_template or template_content)
+    )
 
     # Test assertion
     assert result == expected_output, (
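Note: the expected strings in MODEL_TEMPLATE_GENERATON_OUTPUT come from rendering a ChatML-style template over the test conversation. A minimal sketch of that expansion, using an inline Jinja template that only approximates examples/template_chatml.jinja (the inline string and its exact whitespace are assumptions, not the repo file):

# Illustrative sketch: render a ChatML-style template over the test conversation.
# The inline template approximates examples/template_chatml.jinja and may differ
# from the real file in whitespace details.
from jinja2 import Template

CHATML_LIKE = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
    {"role": "user", "content": "What is the capital of"},
]

print(Template(CHATML_LIKE).render(messages=messages,
                                   add_generation_prompt=True))
# Produces <|im_start|>user / Hello<|im_end|> ... blocks similar to the
# expected outputs asserted above.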

View File

@@ -1,10 +1,12 @@
 import openai  # use the official client for correctness check
 import pytest
 
-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
 
 
 @pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray"
+        "--engine-use-ray",
+        "--chat-template",
+        str(chatml_jinja_path),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+        completion_tokens=10, prompt_tokens=55, total_tokens=65)
 
     message = choice.message
     assert message.content is not None and len(message.content) >= 10
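Note: prompt_tokens grows from 13 to 55 because the server now renders the ChatML template around the conversation before tokenizing, so the <|im_start|>/<|im_end|> markup counts toward the prompt. A rough way to reproduce that kind of count locally, assuming a transformers install and a vLLM checkout for the template path (the messages below are placeholders, not the test's exact payload, so the number will differ):

# Rough sanity check: count prompt tokens once a chat template is applied.
# Assumes transformers is installed and the path points at the template in a
# vLLM checkout; the messages here are placeholders.
from pathlib import Path

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
chat_template = Path("examples/template_chatml.jinja").read_text()

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]

# tokenize=True returns token ids, so len() gives the prompt token count.
ids = tokenizer.apply_chat_template(messages,
                                    chat_template=chat_template,
                                    tokenize=True,
                                    add_generation_prompt=True)
print(len(ids))  # noticeably larger than tokenizing the raw text alone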

View File

@@ -9,6 +9,11 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port
 
+from ...utils import VLLM_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+
 
 class MyOPTForCausalLM(OPTForCausalLM):
@@ -21,12 +26,25 @@ class MyOPTForCausalLM(OPTForCausalLM):
         return logits
 
 
-def server_function(port):
+def server_function(port: int):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
-         f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+    sys.argv = ["placeholder.py"] + [
+        "--model",
+        "facebook/opt-125m",
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--api-key",
+        "token-abc123",
+        "--port",
+        str(port),
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
     ctx = torch.multiprocessing.get_context()
     server = ctx.Process(target=server_function, args=(port, ))
     server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = OpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    now = time.time()
-    while True:
-        try:
-            completion = client.chat.completions.create(
-                model="facebook/opt-125m",
-                messages=[{
-                    "role": "system",
-                    "content": "You are a helpful assistant."
-                }, {
-                    "role": "user",
-                    "content": "Hello!"
-                }],
-                temperature=0,
-            )
-            break
-        except OpenAIError as e:
-            if "Connection error" in str(e):
-                time.sleep(3)
-                if time.time() - now > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError("Server did not start in time") from e
-            else:
-                raise e
-    server.kill()
+
+    try:
+        client = OpenAI(
+            base_url=f"http://localhost:{port}/v1",
+            api_key="token-abc123",
+        )
+        now = time.time()
+        while True:
+            try:
+                completion = client.chat.completions.create(
+                    model="facebook/opt-125m",
+                    messages=[{
+                        "role": "system",
+                        "content": "You are a helpful assistant."
+                    }, {
+                        "role": "user",
+                        "content": "Hello!"
+                    }],
+                    temperature=0,
+                )
+                break
+            except OpenAIError as e:
+                if "Connection error" in str(e):
+                    time.sleep(3)
+                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+                        msg = "Server did not start in time"
+                        raise RuntimeError(msg) from e
+                else:
+                    raise e
+    finally:
+        server.terminate()
+
     generated_text = completion.choices[0].message.content
+    assert generated_text is not None
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""

View File

@@ -50,7 +50,7 @@ VLLM_PATH = Path(__file__).parent.parent
 
 class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
-    MAX_SERVER_START_WAIT_S = 120  # wait for server to start for 120 seconds
+    MAX_START_WAIT_S = 120  # wait for server to start for 120 seconds
 
     def __init__(
         self,
@@ -85,7 +85,7 @@ class RemoteOpenAIServer:
                                      stdout=sys.stdout,
                                      stderr=sys.stderr)
         self._wait_for_server(url=self.url_for("health"),
-                              timeout=self.MAX_SERVER_START_WAIT_S)
+                              timeout=self.MAX_START_WAIT_S)
 
     def __enter__(self):
         return self

View File

@@ -1,8 +1,9 @@
 import codecs
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
-                    final)
+from pathlib import Path
+from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
+                    cast, final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -22,6 +23,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
     mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
-def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+def load_chat_template(
+        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
         return None
 
     try:
         with open(chat_template, "r") as f:
             resolved_chat_template = f.read()
     except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
         JINJA_CHARS = "{}\n"
         if not any(c in chat_template for c in JINJA_CHARS):
             msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
         mm_futures.extend(parse_result.mm_futures)
 
     return conversation, mm_futures
+
+
+def apply_chat_template(
+    tokenizer: AnyTokenizer,
+    conversation: List[ConversationMessage],
+    chat_template: Optional[str],
+    *,
+    tokenize: bool = False,  # Different from HF's default
+    **kwargs: Any,
+) -> str:
+    if chat_template is None and tokenizer.chat_template is None:
+        raise ValueError(
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one.")
+
+    prompt = tokenizer.apply_chat_template(
+        conversation=conversation,
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )
+    assert isinstance(prompt, str)
+    return prompt
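Note: apply_chat_template is the new single entry point used by the tests and the OpenAI-serving code below. A minimal usage sketch, assuming vLLM is importable and the working directory is a vLLM checkout so the relative template path resolves (the model name and path are only examples):

# Usage sketch for the helper added in this diff; not part of the commit.
from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
conversation = [{"role": "user", "content": "Hello"}]

# With an explicit template, the helper returns the rendered prompt string.
template = load_chat_template("examples/template_chatml.jinja")
prompt = apply_chat_template(tokenizer,
                             conversation,
                             chat_template=template,
                             add_generation_prompt=True)

# Without a template, and with a tokenizer that defines none, it now raises
# ValueError instead of silently using a transformers default template.
try:
    apply_chat_template(tokenizer, conversation, chat_template=None)
except ValueError as err:
    print(err)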

View File

@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
         default=None,
         description=(
             "A Jinja template to use for this conversion. "
-            "If this is not passed, the model's default chat template will be "
-            "used instead."),
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
     )
     chat_template_kwargs: Optional[Dict[str, Any]] = Field(
         default=None,
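Note: because chat_template is a per-request field, a client can supply a template when the served tokenizer defines none. A hedged sketch using the official openai client against a vLLM OpenAI-compatible server; the base URL, API key, and template path are placeholders, and the extra field rides along via the client's extra_body passthrough:

# Sketch: pass a chat template per request through the OpenAI client.
# Base URL, API key, and template path are placeholders for a local vLLM server.
from pathlib import Path

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="token-abc123")
chatml_template = Path("examples/template_chatml.jinja").read_text()

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"chat_template": chatml_template},
)
print(completion.choices[0].message.content)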

View File

@@ -10,6 +10,7 @@ from transformers import PreTrainedTokenizer
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
+                                         apply_chat_template,
                                          load_chat_template,
                                          parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
@@ -99,16 +100,15 @@ class OpenAIServingChat(OpenAIServing):
                     tool.model_dump() for tool in request.tools
                 ]
 
-            prompt = tokenizer.apply_chat_template(
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
+                chat_template=request.chat_template or self.chat_template,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
-            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))

View File

@@ -2,7 +2,9 @@ from typing import List, Optional, Union
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.chat_utils import (apply_chat_template,
+                                         load_chat_template,
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -70,12 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")
 
-            prompt = tokenizer.apply_chat_template(
-                add_generation_prompt=request.add_generation_prompt,
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
-                chat_template=self.chat_template)
-            assert isinstance(prompt, str)
+                chat_template=self.chat_template,
+                add_generation_prompt=request.add_generation_prompt,
+            )
         else:
             prompt = request.prompt

View File

@@ -12,12 +12,12 @@ from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
+from .tokenizer_group import AnyTokenizer
+
 logger = init_logger(__name__)
 
 
-def get_cached_tokenizer(
-    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.
 
     This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
     revision: Optional[str] = None,
     download_dir: Optional[str] = None,
     **kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> AnyTokenizer:
     """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
     """
     if VLLM_USE_MODELSCOPE:
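Note: the explicit Union annotations are collapsed into the AnyTokenizer alias imported from .tokenizer_group. A minimal sketch of what such an alias amounts to, assuming it simply names the union the old signatures spelled out (the real definition may cover additional tokenizer types):

# Hypothetical sketch of an AnyTokenizer-style alias; the real alias lives in
# vllm.transformers_utils.tokenizer_group and may include more tokenizer types.
from typing import Union

from transformers import (AutoTokenizer, PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def get_tokenizer_sketch(name: str) -> AnyTokenizer:
    """Either tokenizer flavor satisfies the single alias annotation."""
    return AutoTokenizer.from_pretrained(name)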