[Tool parsing] Improve / correct mistral tool parsing (#10333)
This commit is contained in:
parent
554af9228d
commit
11cd1ae6ad
@ -2,9 +2,13 @@
|
||||
|
||||
Run `pytest tests/models/test_mistral.py`.
|
||||
"""
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
||||
MistralToolParser)
|
||||
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
@ -58,17 +62,69 @@ TOOLS = [{
|
||||
},
|
||||
"required": ["city", "state", "unit"]
|
||||
}
|
||||
},
|
||||
}, {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rewrite",
|
||||
"description": "Rewrites text",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": [],
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The input text to rewrite."
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
MSGS = [{
|
||||
"role":
|
||||
"user",
|
||||
"content": ("Can you tell me what the temperate"
|
||||
" will be in Dallas, in fahrenheit?")
|
||||
}]
|
||||
EXPECTED_FUNC_CALL = (
|
||||
'[{"name": "get_current_weather", "arguments": '
|
||||
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
|
||||
MSGS = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are an assistant."
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content":
|
||||
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"assistant",
|
||||
"content":
|
||||
"",
|
||||
"tool_calls": [{
|
||||
"id": "bbc5b7ede",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name":
|
||||
"rewrite",
|
||||
"arguments":
|
||||
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa
|
||||
}
|
||||
}]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content":
|
||||
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa
|
||||
"tool_call_id": "bbc5b7ede",
|
||||
"name": "rewrite"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "---\n\nMy English needs improving, maybe I make errors"
|
||||
},
|
||||
{
|
||||
"role":
|
||||
"user",
|
||||
"content": ("Can you tell me what the temperate"
|
||||
" will be in Dallas, in fahrenheit?")
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@ -175,8 +231,23 @@ def test_mistral_function_calling(
|
||||
tokenizer_mode="mistral",
|
||||
config_format="mistral",
|
||||
load_format="mistral") as vllm_model:
|
||||
outputs = vllm_model.model.chat(MSGS,
|
||||
|
||||
msgs = copy.deepcopy(MSGS)
|
||||
outputs = vllm_model.model.chat(msgs,
|
||||
tools=TOOLS,
|
||||
sampling_params=SAMPLING_PARAMS)
|
||||
|
||||
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
|
||||
tokenizer = vllm_model.model.get_tokenizer()
|
||||
tool_parser = MistralToolParser(tokenizer)
|
||||
|
||||
model_output = outputs[0].outputs[0].text.strip()
|
||||
assert model_output.startswith(tool_parser.bot_token), model_output
|
||||
parsed_message = tool_parser.extract_tool_calls(model_output, None)
|
||||
|
||||
assert parsed_message.tools_called
|
||||
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
|
||||
assert parsed_message.tool_calls[
|
||||
0].function.name == "get_current_weather"
|
||||
assert parsed_message.tool_calls[
|
||||
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
|
||||
assert parsed_message.content is None
|
||||
|
||||
@ -30,6 +30,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
|
||||
from vllm.utils import iterate_with_cancellation
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -127,41 +128,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
return self.create_error_response(
|
||||
"tool_choice = \"required\" is not supported!")
|
||||
|
||||
# NOTE: There is currently a bug in pydantic where attributes
|
||||
# declared as iterables are replaced in in the instances by
|
||||
# pydantic-core ValidatorIterator instance. In particular, this
|
||||
# affects tool_calls defined in ChatCompletionAssistantMessageParam
|
||||
# model:
|
||||
# see:
|
||||
# - https://github.com/pydantic/pydantic/issues/9467
|
||||
# As a result, tool_calls from assistant messages are never
|
||||
# deserialized in the request object if the tool_calls iterator is
|
||||
# not consumed. This affect messages passed to the MistralTokenizer
|
||||
# since no chat template is applied and therefore the tools_calls
|
||||
# iterator is not directly consumed.
|
||||
# Issue is tracked on Pydantic side, with resolution planned for
|
||||
# v2.11 release. In the meantime, the official workaround is to
|
||||
# consume the iterator so the tool_calls are correctly deserialized
|
||||
# in the OpenAI ChatCompletionAssistantMessageParam object
|
||||
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
|
||||
# Official Pydantic Issues:
|
||||
# - https://github.com/pydantic/pydantic/issues/9541
|
||||
# TODO: remove when pydantic v2.11 is released
|
||||
# because of issues with pydantic we need to potentially
|
||||
# re-serialize the tool_calls field of the request
|
||||
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.get("role") == 'assistant':
|
||||
tool_calls_validator = message.get(
|
||||
"tool_calls", ().__iter__())
|
||||
validated_tool_calls = []
|
||||
while True:
|
||||
try:
|
||||
tool_call = next(
|
||||
tool_calls_validator) # type: ignore
|
||||
validated_tool_calls.append(tool_call)
|
||||
except StopIteration:
|
||||
break
|
||||
request.messages[i][
|
||||
"tool_calls"] = validated_tool_calls
|
||||
maybe_serialize_tool_calls(request)
|
||||
|
||||
if (request.tool_choice == "auto" and
|
||||
not (self.enable_auto_tools and tool_parser is not None)
|
||||
|
||||
@ -62,7 +62,7 @@ class MistralToolParser(ToolParser):
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
if self.bot_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Mistral Tool Parser could not locate the tool call token in "
|
||||
@ -84,16 +84,25 @@ class MistralToolParser(ToolParser):
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
# first remove the BOT token
|
||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||
|
||||
try:
|
||||
|
||||
# use a regex to find the tool call. remove the BOT token
|
||||
# and make sure to replace single quotes with double quotes
|
||||
raw_tool_call = self.tool_call_regex.findall(
|
||||
model_output.replace(self.bot_token, ""))[0]
|
||||
# we first try to directly load the json as parsing very nested
|
||||
# jsons is difficult
|
||||
try:
|
||||
function_call_arr = json.loads(tool_content)
|
||||
except json.JSONDecodeError:
|
||||
# use a regex to find the part corresponding to the tool call.
|
||||
# NOTE: This use case should not happen if the model is trained
|
||||
# correctly. It's a easy possible fix so it's included, but
|
||||
# can be brittle for very complex / highly nested tool calls
|
||||
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
tool_calls: List[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
@ -116,7 +125,7 @@ class MistralToolParser(ToolParser):
|
||||
# return information to just treat the tool call as regular JSON
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
content=tool_content)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from .mistral import MistralTokenizer
|
||||
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
|
||||
|
||||
__all__ = ["MistralTokenizer"]
|
||||
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
import huggingface_hub
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
# yapf: disable
|
||||
from mistral_common.tokens.tokenizers.mistral import (
|
||||
MistralTokenizer as PublicMistralTokenizer)
|
||||
@ -29,6 +30,43 @@ class Encoding:
|
||||
input_ids: List[int]
|
||||
|
||||
|
||||
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
||||
# SEE: https://github.com/vllm-project/vllm/pull/9951
|
||||
# Credits go to: @gcalmettes
|
||||
# NOTE: There is currently a bug in pydantic where attributes
|
||||
# declared as iterables are replaced in in the instances by
|
||||
# pydantic-core ValidatorIterator instance. In particular, this
|
||||
# affects tool_calls defined in ChatCompletionAssistantMessageParam
|
||||
# model:
|
||||
# see:
|
||||
# - https://github.com/pydantic/pydantic/issues/9467
|
||||
# As a result, tool_calls from assistant messages are never
|
||||
# deserialized in the request object if the tool_calls iterator is
|
||||
# not consumed. This affect messages passed to the MistralTokenizer
|
||||
# since no chat template is applied and therefore the tools_calls
|
||||
# iterator is not directly consumed.
|
||||
# Issue is tracked on Pydantic side, with resolution planned for
|
||||
# v2.11 release. In the meantime, the official workaround is to
|
||||
# consume the iterator so the tool_calls are correctly deserialized
|
||||
# in the OpenAI ChatCompletionAssistantMessageParam object
|
||||
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
|
||||
# Official Pydantic Issues:
|
||||
# - https://github.com/pydantic/pydantic/issues/9541
|
||||
# TODO: remove when pydantic v2.11 is released
|
||||
for i, message in enumerate(request.messages):
|
||||
if message.get("role") == 'assistant':
|
||||
tool_calls_validator = message.get("tool_calls", ().__iter__())
|
||||
validated_tool_calls = []
|
||||
while True:
|
||||
try:
|
||||
tool_call = next(tool_calls_validator) # type: ignore
|
||||
validated_tool_calls.append(tool_call)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
request.messages[i]["tool_calls"] = validated_tool_calls
|
||||
|
||||
|
||||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
||||
repo_cache = os.path.join(
|
||||
huggingface_hub.constants.HF_HUB_CACHE,
|
||||
@ -222,7 +260,8 @@ class MistralTokenizer:
|
||||
if self.is_tekken:
|
||||
tokens = [
|
||||
t for t in tokens
|
||||
if t not in self.tokenizer._all_special_tokens
|
||||
if (t is SpecialTokens.tool_calls
|
||||
or t not in self.tokenizer._all_special_tokens)
|
||||
]
|
||||
|
||||
if any(isinstance(t, bytes) for t in tokens):
|
||||
@ -246,7 +285,27 @@ class MistralTokenizer:
|
||||
else:
|
||||
decoded = "".join(tokens)
|
||||
else:
|
||||
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type]
|
||||
# make sure certain special tokens like Tool calls are
|
||||
# not decoded
|
||||
special_tokens = {SpecialTokens.tool_calls}
|
||||
regular_tokens: List[str] = []
|
||||
decoded_list = []
|
||||
|
||||
for token in tokens:
|
||||
if token in special_tokens:
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.tokenizer.decode(regular_tokens))
|
||||
regular_tokens = []
|
||||
decoded_list.append(token)
|
||||
else:
|
||||
regular_tokens.append(token)
|
||||
|
||||
if regular_tokens:
|
||||
decoded_list.append(
|
||||
self.decode(regular_tokens)) # type: ignore
|
||||
|
||||
decoded = ''.join(decoded_list)
|
||||
|
||||
return decoded
|
||||
|
||||
@ -274,8 +333,11 @@ class MistralTokenizer:
|
||||
assert self.is_tekken or self.is_spm, type(self.tokenizer)
|
||||
|
||||
if self.is_tekken:
|
||||
# skip special tokens
|
||||
ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
|
||||
# skip special tokens except tool call
|
||||
ids = [
|
||||
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
|
||||
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
|
||||
]
|
||||
|
||||
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user