[BugFix] Fix Hermes tool parser emitting incorrect streamed arguments in some cases (#10395) (#10398)

Signed-off-by: xiyuan lee <lixiyuan@haier.com>
This commit is contained in:
COSMOPlat 2024-11-19 21:42:50 +08:00 committed by GitHub
parent b4614656b8
commit f028dff33d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -12,8 +12,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.entrypoints.openai.tool_parsers.utils import (
extract_intermediate_diff)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
@ -190,8 +188,11 @@ class Hermes2ProToolParser(ToolParser):
diff = self.prev_tool_call_arr[self.current_tool_id].get(
"arguments")
if diff:
diff = json.dumps(diff).replace(
self.streamed_args_for_tool[self.current_tool_id], "")
diff = diff.encode('utf-8').decode(
'unicode_escape') if diff is str else diff
diff = json.dumps(
diff, ensure_ascii=False
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
logger.debug(
"Finishing tool and found diff that had not "
"been streamed yet: %s", diff)
@ -307,22 +308,20 @@ class Hermes2ProToolParser(ToolParser):
# last case -- we have an update to existing arguments.
elif cur_arguments and prev_arguments:
if isinstance(delta_text, str) and len(delta_text.rstrip(
)) >= 1 and delta_text.rstrip()[-1] == '}':
delta_text = delta_text.rstrip()[:-1]
logger.debug("got diff %s", delta_text)
cur_args_json = json.dumps(cur_arguments)
prev_args_json = json.dumps(prev_arguments)
logger.debug("Searching for diff between\n%s", cur_args_json)
logger.debug("and\n%s", prev_args_json)
argument_diff = extract_intermediate_diff(
cur_args_json, prev_args_json)
logger.debug("got argument diff %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).model_dump(
arguments=delta_text).model_dump(
exclude_none=True))
])
self.streamed_args_for_tool[self.current_tool_id] \
+= argument_diff
+= delta_text
# handle saving the state for the current tool into
# the "prev" list for use in diffing for the next iteration