[BugFix] Fix Hermes tool parser emitting incorrect streamed arguments in some cases (#10395) (#10398)

Signed-off-by: xiyuan lee <lixiyuan@haier.com>
This commit is contained in:
COSMOPlat 2024-11-19 21:42:50 +08:00 committed by GitHub
parent b4614656b8
commit f028dff33d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,8 +12,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
FunctionCall, ToolCall) FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager) ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid from vllm.utils import random_uuid
@ -190,8 +188,11 @@ class Hermes2ProToolParser(ToolParser):
diff = self.prev_tool_call_arr[self.current_tool_id].get( diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments") "arguments")
if diff: if diff:
diff = json.dumps(diff).replace( diff = diff.encode('utf-8').decode(
self.streamed_args_for_tool[self.current_tool_id], "") 'unicode_escape') if diff is str else diff
diff = json.dumps(
diff, ensure_ascii=False
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
logger.debug( logger.debug(
"Finishing tool and found diff that had not " "Finishing tool and found diff that had not "
"been streamed yet: %s", diff) "been streamed yet: %s", diff)
@ -307,22 +308,20 @@ class Hermes2ProToolParser(ToolParser):
# last case -- we have an update to existing arguments. # last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments: elif cur_arguments and prev_arguments:
if isinstance(delta_text, str) and len(delta_text.rstrip(
)) >= 1 and delta_text.rstrip()[-1] == '}':
delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text)
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between\n%s", cur_args_json)
logger.debug("and\n%s", prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got argument diff %s", argument_diff)
delta = DeltaMessage(tool_calls=[ delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id, DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall( function=DeltaFunctionCall(
arguments=argument_diff).model_dump( arguments=delta_text).model_dump(
exclude_none=True)) exclude_none=True))
]) ])
self.streamed_args_for_tool[self.current_tool_id] \ self.streamed_args_for_tool[self.current_tool_id] \
+= argument_diff += delta_text
# handle saving the state for the current tool into # handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration # the "prev" list for use in diffing for the next iteration