[Frontend] Gracefully handle missing chat template and fix CI failure (#7238)

Co-authored-by: Roger Wang <ywang@roblox.com>

commit 66d617e343 (parent 7b261092de)
@@ -1,22 +1,16 @@
-import os
-import pathlib
-
 import pytest
 
-from vllm.entrypoints.chat_utils import load_chat_template
+from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath(
-    __file__))).parent.parent / "examples/template_chatml.jinja"
+from ..utils import VLLM_PATH
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
 assert chatml_jinja_path.exists()
 
 # Define models, templates, and their corresponding expected outputs
 MODEL_TEMPLATE_GENERATON_OUTPUT = [
-    ("facebook/opt-125m", None, True,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
-    ("facebook/opt-125m", None, False,
-     "Hello</s>Hi there!</s>What is the capital of</s>"),
     ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
 Hello<|im_end|>
 <|im_start|>assistant
@@ -93,11 +87,12 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)
 
     # Call the function and get the result
-    result = tokenizer.apply_chat_template(
+    result = apply_chat_template(
+        tokenizer,
         conversation=mock_request.messages,
-        tokenize=False,
+        chat_template=mock_request.chat_template or template_content,
         add_generation_prompt=mock_request.add_generation_prompt,
-        chat_template=mock_request.chat_template or template_content)
+    )
 
     # Test assertion
     assert result == expected_output, (
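Note on the remaining parametrizations: the template is now loaded from the file and passed explicitly through the wrapper introduced in vllm/entrypoints/chat_utils.py further down, instead of going through tokenizer.apply_chat_template. A minimal sketch of that call shape, assuming vLLM is installed and using an illustrative relative path to the ChatML template (the tests resolve it via VLLM_PATH instead):

from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

# Illustrative path; adjust to wherever examples/template_chatml.jinja lives.
template_content = load_chat_template("examples/template_chatml.jinja")
tokenizer = get_tokenizer("facebook/opt-125m")

prompt = apply_chat_template(
    tokenizer,
    conversation=[{"role": "user", "content": "Hello"}],
    chat_template=template_content,
    add_generation_prompt=True,
)
# With the ChatML template this renders something like:
# <|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n
print(prompt)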
@@ -1,10 +1,12 @@
 import openai  # use the official client for correctness check
 import pytest
 
-from ..utils import RemoteOpenAIServer
+from ..utils import VLLM_PATH, RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
 
 
 @pytest.fixture(scope="module")
@@ -16,7 +18,9 @@ def server():
         "--max-model-len",
         "2048",
         "--enforce-eager",
-        "--engine-use-ray"
+        "--engine-use-ray",
+        "--chat-template",
+        str(chatml_jinja_path),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -83,7 +87,7 @@ async def test_single_chat_session(client: openai.AsyncOpenAI):
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
    assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+        completion_tokens=10, prompt_tokens=55, total_tokens=65)
 
     message = choice.message
     assert message.content is not None and len(message.content) >= 10
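The adjusted usage assertion follows from the fixture change above: requests are now rendered through the ChatML template (system turn plus <|im_start|>/<|im_end|> markup), so the prompt is longer than what the old implicit default produced. A rough way to inspect the rendered prompt length locally, assuming messages similar to the ones the test sends (contents and path are illustrative):

from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")
template = load_chat_template("examples/template_chatml.jinja")  # illustrative path

messages = [{
    "role": "system",
    "content": "you are a helpful assistant"
}, {
    "role": "user",
    "content": "what is 1+1?"
}]

prompt = apply_chat_template(tokenizer,
                             conversation=messages,
                             chat_template=template,
                             add_generation_prompt=True)
# The ChatML markup makes the prompt noticeably longer than the bare messages.
print(len(tokenizer(prompt).input_ids))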
@@ -9,6 +9,11 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port
 
+from ...utils import VLLM_PATH, RemoteOpenAIServer
+
+chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
+assert chatml_jinja_path.exists()
+
 
 class MyOPTForCausalLM(OPTForCausalLM):
 
@@ -21,12 +26,25 @@ class MyOPTForCausalLM(OPTForCausalLM):
         return logits
 
 
-def server_function(port):
+def server_function(port: int):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
-    sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
-         f"--dtype float32 --api-key token-abc123 --port {port}").split()
+
+    sys.argv = ["placeholder.py"] + [
+        "--model",
+        "facebook/opt-125m",
+        "--gpu-memory-utilization",
+        "0.10",
+        "--dtype",
+        "float32",
+        "--api-key",
+        "token-abc123",
+        "--port",
+        str(port),
+        "--chat-template",
+        str(chatml_jinja_path),
+    ]
+
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
 
@@ -36,35 +54,40 @@ def test_oot_registration_for_api_server():
     ctx = torch.multiprocessing.get_context()
     server = ctx.Process(target=server_function, args=(port, ))
     server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = OpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    now = time.time()
-    while True:
-        try:
-            completion = client.chat.completions.create(
-                model="facebook/opt-125m",
-                messages=[{
-                    "role": "system",
-                    "content": "You are a helpful assistant."
-                }, {
-                    "role": "user",
-                    "content": "Hello!"
-                }],
-                temperature=0,
-            )
-            break
-        except OpenAIError as e:
-            if "Connection error" in str(e):
-                time.sleep(3)
-                if time.time() - now > MAX_SERVER_START_WAIT_S:
-                    raise RuntimeError("Server did not start in time") from e
-            else:
-                raise e
-    server.kill()
+
+    try:
+        client = OpenAI(
+            base_url=f"http://localhost:{port}/v1",
+            api_key="token-abc123",
+        )
+        now = time.time()
+        while True:
+            try:
+                completion = client.chat.completions.create(
+                    model="facebook/opt-125m",
+                    messages=[{
+                        "role": "system",
+                        "content": "You are a helpful assistant."
+                    }, {
+                        "role": "user",
+                        "content": "Hello!"
+                    }],
+                    temperature=0,
+                )
+                break
+            except OpenAIError as e:
+                if "Connection error" in str(e):
+                    time.sleep(3)
+                    if time.time() - now > RemoteOpenAIServer.MAX_START_WAIT_S:
+                        msg = "Server did not start in time"
+                        raise RuntimeError(msg) from e
+                else:
+                    raise e
+    finally:
+        server.terminate()
+
     generated_text = completion.choices[0].message.content
+    assert generated_text is not None
     # make sure only the first token is generated
     rest = generated_text.replace("<s>", "")
     assert rest == ""
@@ -50,7 +50,7 @@ VLLM_PATH = Path(__file__).parent.parent
 
 class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
-    MAX_SERVER_START_WAIT_S = 120  # wait for server to start for 120 seconds
+    MAX_START_WAIT_S = 120  # wait for server to start for 120 seconds
 
     def __init__(
         self,
@@ -85,7 +85,7 @@ class RemoteOpenAIServer:
                                      stdout=sys.stdout,
                                      stderr=sys.stderr)
         self._wait_for_server(url=self.url_for("health"),
-                              timeout=self.MAX_SERVER_START_WAIT_S)
+                              timeout=self.MAX_START_WAIT_S)
 
     def __enter__(self):
         return self
@@ -1,8 +1,9 @@
 import codecs
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import (Awaitable, Iterable, List, Optional, Tuple, Union, cast,
-                    final)
+from pathlib import Path
+from typing import (Any, Awaitable, Iterable, List, Optional, Tuple, Union,
+                    cast, final)
 
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -22,6 +23,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import async_get_and_parse_image
+from vllm.transformers_utils.tokenizer import AnyTokenizer
 
 logger = init_logger(__name__)
 
@@ -69,13 +71,17 @@ class ChatMessageParseResult:
     mm_futures: List[Awaitable[MultiModalDataDict]]
 
 
-def load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+def load_chat_template(
+        chat_template: Optional[Union[Path, str]]) -> Optional[str]:
     if chat_template is None:
         return None
     try:
         with open(chat_template, "r") as f:
             resolved_chat_template = f.read()
     except OSError as e:
+        if isinstance(chat_template, Path):
+            raise
+
         JINJA_CHARS = "{}\n"
         if not any(c in chat_template for c in JINJA_CHARS):
             msg = (f"The supplied chat template ({chat_template}) "
@@ -208,3 +214,28 @@ def parse_chat_messages(
         mm_futures.extend(parse_result.mm_futures)
 
     return conversation, mm_futures
+
+
+def apply_chat_template(
+    tokenizer: AnyTokenizer,
+    conversation: List[ConversationMessage],
+    chat_template: Optional[str],
+    *,
+    tokenize: bool = False,  # Different from HF's default
+    **kwargs: Any,
+) -> str:
+    if chat_template is None and tokenizer.chat_template is None:
+        raise ValueError(
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one.")
+
+    prompt = tokenizer.apply_chat_template(
+        conversation=conversation,
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )
+    assert isinstance(prompt, str)
+
+    return prompt
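A small usage sketch of the helper added here, showing both the new error path and a successful call; the model name and template path are illustrative:

from vllm.entrypoints.chat_utils import apply_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")  # ships no chat_template of its own
conversation = [{"role": "user", "content": "Hello"}]

try:
    # Neither the caller nor the tokenizer provides a template -> explicit error
    apply_chat_template(tokenizer, conversation, chat_template=None)
except ValueError as e:
    print(e)

# Passing a template string succeeds and returns the rendered prompt.
with open("examples/template_chatml.jinja") as f:  # illustrative path
    prompt = apply_chat_template(tokenizer,
                                 conversation,
                                 chat_template=f.read(),
                                 add_generation_prompt=True)
assert isinstance(prompt, str)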
@@ -190,8 +190,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
         default=None,
         description=(
             "A Jinja template to use for this conversion. "
-            "If this is not passed, the model's default chat template will be "
-            "used instead."),
+            "As of transformers v4.44, default chat template is no longer "
+            "allowed, so you must provide a chat template if the tokenizer "
+            "does not define one."),
     )
     chat_template_kwargs: Optional[Dict[str, Any]] = Field(
         default=None,
@@ -10,6 +10,7 @@ from transformers import PreTrainedTokenizer
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
+                                         apply_chat_template,
                                          load_chat_template,
                                          parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
@@ -99,16 +100,15 @@ class OpenAIServingChat(OpenAIServing):
                     tool.model_dump() for tool in request.tools
                 ]
 
-            prompt = tokenizer.apply_chat_template(
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
+                chat_template=request.chat_template or self.chat_template,
                 add_generation_prompt=request.add_generation_prompt,
                 tools=tool_dicts,
                 documents=request.documents,
-                chat_template=request.chat_template or self.chat_template,
                 **(request.chat_template_kwargs or {}),
             )
-            assert isinstance(prompt, str)
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
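Because request.chat_template still takes precedence over the server-level template here, clients can override the template per request; with the official OpenAI client that goes through extra_body. A sketch, with the base URL, API key, and the simplified ChatML-style template all illustrative (it is not the exact template under examples/):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")

# Simplified ChatML-style Jinja template, written inline for illustration.
chatml = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}")

completion = client.chat.completions.create(
    model="facebook/opt-125m",
    messages=[{"role": "user", "content": "Hello!"}],
    extra_body={"chat_template": chatml},  # maps to ChatCompletionRequest.chat_template
)
print(completion.choices[0].message.content)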
@@ -2,7 +2,9 @@ from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import load_chat_template, parse_chat_messages
+from vllm.entrypoints.chat_utils import (apply_chat_template,
+                                         load_chat_template,
+                                         parse_chat_messages)
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -70,12 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")
 
-            prompt = tokenizer.apply_chat_template(
-                add_generation_prompt=request.add_generation_prompt,
+            prompt = apply_chat_template(
+                tokenizer,
                 conversation=conversation,
-                tokenize=False,
-                chat_template=self.chat_template)
-            assert isinstance(prompt, str)
+                chat_template=self.chat_template,
+                add_generation_prompt=request.add_generation_prompt,
+            )
         else:
             prompt = request.prompt
 
@@ -12,12 +12,12 @@ from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
+from .tokenizer_group import AnyTokenizer
+
 logger = init_logger(__name__)
 
 
-def get_cached_tokenizer(
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer:
     """Get tokenizer with cached properties.
 
     This will patch the tokenizer object in place.
@@ -63,7 +63,7 @@ def get_tokenizer(
     revision: Optional[str] = None,
     download_dir: Optional[str] = None,
     **kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> AnyTokenizer:
     """Gets a tokenizer for the given model name via HuggingFace or ModelScope.
     """
     if VLLM_USE_MODELSCOPE:
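For callers, the practical effect of the AnyTokenizer annotation is that whatever get_tokenizer returns can be handed straight to apply_chat_template above, and its chat_template attribute may legitimately be None. A small sketch (model name illustrative):

from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("facebook/opt-125m")

# Hugging Face tokenizers expose `chat_template`; it is None when the model
# ships no template, which is exactly the case the new ValueError guards against.
print(tokenizer.chat_template)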