[Tool parsing] Improve / correct mistral tool parsing (#10333)
This commit is contained in:
parent
554af9228d
commit
11cd1ae6ad
@ -2,9 +2,13 @@
|
|||||||
|
|
||||||
Run `pytest tests/models/test_mistral.py`.
|
Run `pytest tests/models/test_mistral.py`.
|
||||||
"""
|
"""
|
||||||
|
import copy
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
|
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
|
||||||
|
MistralToolParser)
|
||||||
|
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import check_logprobs_close
|
||||||
|
|
||||||
@ -58,17 +62,69 @@ TOOLS = [{
|
|||||||
},
|
},
|
||||||
"required": ["city", "state", "unit"]
|
"required": ["city", "state", "unit"]
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "rewrite",
|
||||||
|
"description": "Rewrites text",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"required": [],
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The input text to rewrite."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}]
|
}]
|
||||||
MSGS = [{
|
MSGS = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are an assistant."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"user",
|
||||||
|
"content":
|
||||||
|
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role":
|
||||||
|
"assistant",
|
||||||
|
"content":
|
||||||
|
"",
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "bbc5b7ede",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name":
|
||||||
|
"rewrite",
|
||||||
|
"arguments":
|
||||||
|
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"content":
|
||||||
|
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa
|
||||||
|
"tool_call_id": "bbc5b7ede",
|
||||||
|
"name": "rewrite"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "---\n\nMy English needs improving, maybe I make errors"
|
||||||
|
},
|
||||||
|
{
|
||||||
"role":
|
"role":
|
||||||
"user",
|
"user",
|
||||||
"content": ("Can you tell me what the temperate"
|
"content": ("Can you tell me what the temperate"
|
||||||
" will be in Dallas, in fahrenheit?")
|
" will be in Dallas, in fahrenheit?")
|
||||||
}]
|
}
|
||||||
EXPECTED_FUNC_CALL = (
|
]
|
||||||
'[{"name": "get_current_weather", "arguments": '
|
|
||||||
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@ -175,8 +231,23 @@ def test_mistral_function_calling(
|
|||||||
tokenizer_mode="mistral",
|
tokenizer_mode="mistral",
|
||||||
config_format="mistral",
|
config_format="mistral",
|
||||||
load_format="mistral") as vllm_model:
|
load_format="mistral") as vllm_model:
|
||||||
outputs = vllm_model.model.chat(MSGS,
|
|
||||||
|
msgs = copy.deepcopy(MSGS)
|
||||||
|
outputs = vllm_model.model.chat(msgs,
|
||||||
tools=TOOLS,
|
tools=TOOLS,
|
||||||
sampling_params=SAMPLING_PARAMS)
|
sampling_params=SAMPLING_PARAMS)
|
||||||
|
|
||||||
assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
|
tokenizer = vllm_model.model.get_tokenizer()
|
||||||
|
tool_parser = MistralToolParser(tokenizer)
|
||||||
|
|
||||||
|
model_output = outputs[0].outputs[0].text.strip()
|
||||||
|
assert model_output.startswith(tool_parser.bot_token), model_output
|
||||||
|
parsed_message = tool_parser.extract_tool_calls(model_output, None)
|
||||||
|
|
||||||
|
assert parsed_message.tools_called
|
||||||
|
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
|
||||||
|
assert parsed_message.tool_calls[
|
||||||
|
0].function.name == "get_current_weather"
|
||||||
|
assert parsed_message.tool_calls[
|
||||||
|
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
|
||||||
|
assert parsed_message.content is None
|
||||||
|
|||||||
@ -30,6 +30,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
|
|||||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||||
from vllm.sequence import Logprob
|
from vllm.sequence import Logprob
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
|
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
|
||||||
from vllm.utils import iterate_with_cancellation
|
from vllm.utils import iterate_with_cancellation
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -127,41 +128,11 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
return self.create_error_response(
|
return self.create_error_response(
|
||||||
"tool_choice = \"required\" is not supported!")
|
"tool_choice = \"required\" is not supported!")
|
||||||
|
|
||||||
# NOTE: There is currently a bug in pydantic where attributes
|
# because of issues with pydantic we need to potentially
|
||||||
# declared as iterables are replaced in in the instances by
|
# re-serialize the tool_calls field of the request
|
||||||
# pydantic-core ValidatorIterator instance. In particular, this
|
# for more info: see comment in `maybe_serialize_tool_calls`
|
||||||
# affects tool_calls defined in ChatCompletionAssistantMessageParam
|
|
||||||
# model:
|
|
||||||
# see:
|
|
||||||
# - https://github.com/pydantic/pydantic/issues/9467
|
|
||||||
# As a result, tool_calls from assistant messages are never
|
|
||||||
# deserialized in the request object if the tool_calls iterator is
|
|
||||||
# not consumed. This affect messages passed to the MistralTokenizer
|
|
||||||
# since no chat template is applied and therefore the tools_calls
|
|
||||||
# iterator is not directly consumed.
|
|
||||||
# Issue is tracked on Pydantic side, with resolution planned for
|
|
||||||
# v2.11 release. In the meantime, the official workaround is to
|
|
||||||
# consume the iterator so the tool_calls are correctly deserialized
|
|
||||||
# in the OpenAI ChatCompletionAssistantMessageParam object
|
|
||||||
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
|
|
||||||
# Official Pydantic Issues:
|
|
||||||
# - https://github.com/pydantic/pydantic/issues/9541
|
|
||||||
# TODO: remove when pydantic v2.11 is released
|
|
||||||
if isinstance(tokenizer, MistralTokenizer):
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
for i, message in enumerate(request.messages):
|
maybe_serialize_tool_calls(request)
|
||||||
if message.get("role") == 'assistant':
|
|
||||||
tool_calls_validator = message.get(
|
|
||||||
"tool_calls", ().__iter__())
|
|
||||||
validated_tool_calls = []
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
tool_call = next(
|
|
||||||
tool_calls_validator) # type: ignore
|
|
||||||
validated_tool_calls.append(tool_call)
|
|
||||||
except StopIteration:
|
|
||||||
break
|
|
||||||
request.messages[i][
|
|
||||||
"tool_calls"] = validated_tool_calls
|
|
||||||
|
|
||||||
if (request.tool_choice == "auto" and
|
if (request.tool_choice == "auto" and
|
||||||
not (self.enable_auto_tools and tool_parser is not None)
|
not (self.enable_auto_tools and tool_parser is not None)
|
||||||
|
|||||||
@ -62,7 +62,7 @@ class MistralToolParser(ToolParser):
|
|||||||
] # map what has been streamed for each tool so far to a list
|
] # map what has been streamed for each tool so far to a list
|
||||||
self.bot_token = "[TOOL_CALLS]"
|
self.bot_token = "[TOOL_CALLS]"
|
||||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
if self.bot_token_id is None:
|
if self.bot_token_id is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Mistral Tool Parser could not locate the tool call token in "
|
"Mistral Tool Parser could not locate the tool call token in "
|
||||||
@ -84,16 +84,25 @@ class MistralToolParser(ToolParser):
|
|||||||
return ExtractedToolCallInformation(tools_called=False,
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
tool_calls=[],
|
tool_calls=[],
|
||||||
content=model_output)
|
content=model_output)
|
||||||
|
|
||||||
|
# first remove the BOT token
|
||||||
|
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# use a regex to find the tool call. remove the BOT token
|
# we first try to directly load the json as parsing very nested
|
||||||
# and make sure to replace single quotes with double quotes
|
# jsons is difficult
|
||||||
raw_tool_call = self.tool_call_regex.findall(
|
try:
|
||||||
model_output.replace(self.bot_token, ""))[0]
|
function_call_arr = json.loads(tool_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
# load the JSON, and then use it to build the Function and
|
# use a regex to find the part corresponding to the tool call.
|
||||||
# Tool Call
|
# NOTE: This use case should not happen if the model is trained
|
||||||
|
# correctly. It's a easy possible fix so it's included, but
|
||||||
|
# can be brittle for very complex / highly nested tool calls
|
||||||
|
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
|
||||||
function_call_arr = json.loads(raw_tool_call)
|
function_call_arr = json.loads(raw_tool_call)
|
||||||
|
|
||||||
|
# Tool Call
|
||||||
tool_calls: List[MistralToolCall] = [
|
tool_calls: List[MistralToolCall] = [
|
||||||
MistralToolCall(
|
MistralToolCall(
|
||||||
type="function",
|
type="function",
|
||||||
@ -116,7 +125,7 @@ class MistralToolParser(ToolParser):
|
|||||||
# return information to just treat the tool call as regular JSON
|
# return information to just treat the tool call as regular JSON
|
||||||
return ExtractedToolCallInformation(tools_called=False,
|
return ExtractedToolCallInformation(tools_called=False,
|
||||||
tool_calls=[],
|
tool_calls=[],
|
||||||
content=model_output)
|
content=tool_content)
|
||||||
|
|
||||||
def extract_tool_calls_streaming(
|
def extract_tool_calls_streaming(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
from .mistral import MistralTokenizer
|
from .mistral import MistralTokenizer, maybe_serialize_tool_calls
|
||||||
|
|
||||||
__all__ = ["MistralTokenizer"]
|
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
|||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from huggingface_hub import HfApi, hf_hub_download
|
from huggingface_hub import HfApi, hf_hub_download
|
||||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||||
|
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from mistral_common.tokens.tokenizers.mistral import (
|
from mistral_common.tokens.tokenizers.mistral import (
|
||||||
MistralTokenizer as PublicMistralTokenizer)
|
MistralTokenizer as PublicMistralTokenizer)
|
||||||
@ -29,6 +30,43 @@ class Encoding:
|
|||||||
input_ids: List[int]
|
input_ids: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_serialize_tool_calls(request: ChatCompletionRequest):
|
||||||
|
# SEE: https://github.com/vllm-project/vllm/pull/9951
|
||||||
|
# Credits go to: @gcalmettes
|
||||||
|
# NOTE: There is currently a bug in pydantic where attributes
|
||||||
|
# declared as iterables are replaced in in the instances by
|
||||||
|
# pydantic-core ValidatorIterator instance. In particular, this
|
||||||
|
# affects tool_calls defined in ChatCompletionAssistantMessageParam
|
||||||
|
# model:
|
||||||
|
# see:
|
||||||
|
# - https://github.com/pydantic/pydantic/issues/9467
|
||||||
|
# As a result, tool_calls from assistant messages are never
|
||||||
|
# deserialized in the request object if the tool_calls iterator is
|
||||||
|
# not consumed. This affect messages passed to the MistralTokenizer
|
||||||
|
# since no chat template is applied and therefore the tools_calls
|
||||||
|
# iterator is not directly consumed.
|
||||||
|
# Issue is tracked on Pydantic side, with resolution planned for
|
||||||
|
# v2.11 release. In the meantime, the official workaround is to
|
||||||
|
# consume the iterator so the tool_calls are correctly deserialized
|
||||||
|
# in the OpenAI ChatCompletionAssistantMessageParam object
|
||||||
|
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
|
||||||
|
# Official Pydantic Issues:
|
||||||
|
# - https://github.com/pydantic/pydantic/issues/9541
|
||||||
|
# TODO: remove when pydantic v2.11 is released
|
||||||
|
for i, message in enumerate(request.messages):
|
||||||
|
if message.get("role") == 'assistant':
|
||||||
|
tool_calls_validator = message.get("tool_calls", ().__iter__())
|
||||||
|
validated_tool_calls = []
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
tool_call = next(tool_calls_validator) # type: ignore
|
||||||
|
validated_tool_calls.append(tool_call)
|
||||||
|
except StopIteration:
|
||||||
|
break
|
||||||
|
|
||||||
|
request.messages[i]["tool_calls"] = validated_tool_calls
|
||||||
|
|
||||||
|
|
||||||
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
|
||||||
repo_cache = os.path.join(
|
repo_cache = os.path.join(
|
||||||
huggingface_hub.constants.HF_HUB_CACHE,
|
huggingface_hub.constants.HF_HUB_CACHE,
|
||||||
@ -222,7 +260,8 @@ class MistralTokenizer:
|
|||||||
if self.is_tekken:
|
if self.is_tekken:
|
||||||
tokens = [
|
tokens = [
|
||||||
t for t in tokens
|
t for t in tokens
|
||||||
if t not in self.tokenizer._all_special_tokens
|
if (t is SpecialTokens.tool_calls
|
||||||
|
or t not in self.tokenizer._all_special_tokens)
|
||||||
]
|
]
|
||||||
|
|
||||||
if any(isinstance(t, bytes) for t in tokens):
|
if any(isinstance(t, bytes) for t in tokens):
|
||||||
@ -246,7 +285,27 @@ class MistralTokenizer:
|
|||||||
else:
|
else:
|
||||||
decoded = "".join(tokens)
|
decoded = "".join(tokens)
|
||||||
else:
|
else:
|
||||||
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type]
|
# make sure certain special tokens like Tool calls are
|
||||||
|
# not decoded
|
||||||
|
special_tokens = {SpecialTokens.tool_calls}
|
||||||
|
regular_tokens: List[str] = []
|
||||||
|
decoded_list = []
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if token in special_tokens:
|
||||||
|
if regular_tokens:
|
||||||
|
decoded_list.append(
|
||||||
|
self.tokenizer.decode(regular_tokens))
|
||||||
|
regular_tokens = []
|
||||||
|
decoded_list.append(token)
|
||||||
|
else:
|
||||||
|
regular_tokens.append(token)
|
||||||
|
|
||||||
|
if regular_tokens:
|
||||||
|
decoded_list.append(
|
||||||
|
self.decode(regular_tokens)) # type: ignore
|
||||||
|
|
||||||
|
decoded = ''.join(decoded_list)
|
||||||
|
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
@ -274,8 +333,11 @@ class MistralTokenizer:
|
|||||||
assert self.is_tekken or self.is_spm, type(self.tokenizer)
|
assert self.is_tekken or self.is_spm, type(self.tokenizer)
|
||||||
|
|
||||||
if self.is_tekken:
|
if self.is_tekken:
|
||||||
# skip special tokens
|
# skip special tokens except tool call
|
||||||
ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
|
ids = [
|
||||||
|
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
|
||||||
|
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
|
||||||
|
]
|
||||||
|
|
||||||
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
|
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user