[BugFix] Enforce Mistral ToolCall id constraint when using the Mistral tool call parser (#9020)
This commit is contained in:
parent
01843c89b8
commit
83caf35e08
@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
|||||||
assert tool_call.type == "function"
|
assert tool_call.type == "function"
|
||||||
assert tool_call.function is not None
|
assert tool_call.function is not None
|
||||||
assert isinstance(tool_call.id, str)
|
assert isinstance(tool_call.id, str)
|
||||||
assert len(tool_call.id) > 16
|
assert len(tool_call.id) >= 9
|
||||||
|
|
||||||
# make sure the weather tool was called correctly
|
# make sure the weather tool was called correctly
|
||||||
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
|
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
|
||||||
@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
|||||||
if tool_call.id:
|
if tool_call.id:
|
||||||
tool_call_id_count += 1
|
tool_call_id_count += 1
|
||||||
assert (isinstance(tool_call.id, str)
|
assert (isinstance(tool_call.id, str)
|
||||||
and (len(tool_call.id) > 16))
|
and (len(tool_call.id) >= 9))
|
||||||
|
|
||||||
# if parts of the function start being streamed
|
# if parts of the function start being streamed
|
||||||
if tool_call.function:
|
if tool_call.function:
|
||||||
|
|||||||
@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
|||||||
assert tool_calls[0].type == 'function'
|
assert tool_calls[0].type == 'function'
|
||||||
assert tool_calls[0].function is not None
|
assert tool_calls[0].function is not None
|
||||||
assert isinstance(tool_calls[0].id, str)
|
assert isinstance(tool_calls[0].id, str)
|
||||||
assert len(tool_calls[0].id) > 16
|
assert len(tool_calls[0].id) >= 9
|
||||||
|
|
||||||
# make sure the weather tool was called (classic example) with arguments
|
# make sure the weather tool was called (classic example) with arguments
|
||||||
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
|
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
|
||||||
@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
|||||||
|
|
||||||
assert finish_reason_count == 1
|
assert finish_reason_count == 1
|
||||||
assert role_name == 'assistant'
|
assert role_name == 'assistant'
|
||||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
|
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
|
||||||
|
|
||||||
# validate the name and arguments
|
# validate the name and arguments
|
||||||
assert function_name == WEATHER_TOOL["function"]["name"]
|
assert function_name == WEATHER_TOOL["function"]["name"]
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
from random import choices
|
||||||
|
from string import ascii_letters, digits
|
||||||
from typing import Dict, List, Sequence, Union
|
from typing import Dict, List, Sequence, Union
|
||||||
|
|
||||||
import partial_json_parser
|
import partial_json_parser
|
||||||
from partial_json_parser.core.options import Allow
|
from partial_json_parser.core.options import Allow
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||||
DeltaToolCall,
|
DeltaToolCall,
|
||||||
@ -19,6 +22,19 @@ from vllm.utils import random_uuid
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
ALPHANUMERIC = ascii_letters + digits
|
||||||
|
|
||||||
|
|
||||||
|
class MistralToolCall(ToolCall):
|
||||||
|
id: str = Field(
|
||||||
|
default_factory=lambda: MistralToolCall.generate_random_id())
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_random_id():
|
||||||
|
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
|
||||||
|
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||||
|
return "".join(choices(ALPHANUMERIC, k=9))
|
||||||
|
|
||||||
|
|
||||||
class MistralToolParser(ToolParser):
|
class MistralToolParser(ToolParser):
|
||||||
"""
|
"""
|
||||||
@ -71,8 +87,8 @@ class MistralToolParser(ToolParser):
|
|||||||
# load the JSON, and then use it to build the Function and
|
# load the JSON, and then use it to build the Function and
|
||||||
# Tool Call
|
# Tool Call
|
||||||
function_call_arr = json.loads(raw_tool_call)
|
function_call_arr = json.loads(raw_tool_call)
|
||||||
tool_calls: List[ToolCall] = [
|
tool_calls: List[MistralToolCall] = [
|
||||||
ToolCall(
|
MistralToolCall(
|
||||||
type="function",
|
type="function",
|
||||||
function=FunctionCall(
|
function=FunctionCall(
|
||||||
name=raw_function_call["name"],
|
name=raw_function_call["name"],
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user