[BugFix] Enforce Mistral ToolCall id constraint when using the Mistral tool call parser (#9020)
This commit is contained in:
parent
01843c89b8
commit
83caf35e08
@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function is not None
|
||||
assert isinstance(tool_call.id, str)
|
||||
assert len(tool_call.id) > 16
|
||||
assert len(tool_call.id) >= 9
|
||||
|
||||
# make sure the weather tool was called correctly
|
||||
assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
|
||||
@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
|
||||
if tool_call.id:
|
||||
tool_call_id_count += 1
|
||||
assert (isinstance(tool_call.id, str)
|
||||
and (len(tool_call.id) > 16))
|
||||
and (len(tool_call.id) >= 9))
|
||||
|
||||
# if parts of the function start being streamed
|
||||
if tool_call.function:
|
||||
|
||||
@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
assert tool_calls[0].type == 'function'
|
||||
assert tool_calls[0].function is not None
|
||||
assert isinstance(tool_calls[0].id, str)
|
||||
assert len(tool_calls[0].id) > 16
|
||||
assert len(tool_calls[0].id) >= 9
|
||||
|
||||
# make sure the weather tool was called (classic example) with arguments
|
||||
assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"]
|
||||
@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI):
|
||||
|
||||
assert finish_reason_count == 1
|
||||
assert role_name == 'assistant'
|
||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16)
|
||||
assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9)
|
||||
|
||||
# validate the name and arguments
|
||||
assert function_name == WEATHER_TOOL["function"]["name"]
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
import json
|
||||
import re
|
||||
from random import choices
|
||||
from string import ascii_letters, digits
|
||||
from typing import Dict, List, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
@ -19,6 +22,19 @@ from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ALPHANUMERIC = ascii_letters + digits
|
||||
|
||||
|
||||
class MistralToolCall(ToolCall):
|
||||
id: str = Field(
|
||||
default_factory=lambda: MistralToolCall.generate_random_id())
|
||||
|
||||
@staticmethod
|
||||
def generate_random_id():
|
||||
# Mistral Tool Call Ids must be alphanumeric with a maximum length of 9.
|
||||
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
@ -71,8 +87,8 @@ class MistralToolParser(ToolParser):
|
||||
# load the JSON, and then use it to build the Function and
|
||||
# Tool Call
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
tool_calls: List[ToolCall] = [
|
||||
ToolCall(
|
||||
tool_calls: List[MistralToolCall] = [
|
||||
MistralToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=raw_function_call["name"],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user