Signed-off-by: xiyuan lee <lixiyuan@haier.com>
This commit is contained in:
parent
b4614656b8
commit
f028dff33d
@ -12,8 +12,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
@ -190,8 +188,11 @@ class Hermes2ProToolParser(ToolParser):
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = json.dumps(diff).replace(
|
||||
self.streamed_args_for_tool[self.current_tool_id], "")
|
||||
diff = diff.encode('utf-8').decode(
|
||||
'unicode_escape') if diff is str else diff
|
||||
diff = json.dumps(
|
||||
diff, ensure_ascii=False
|
||||
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s", diff)
|
||||
@ -307,22 +308,20 @@ class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if isinstance(delta_text, str) and len(delta_text.rstrip(
|
||||
)) >= 1 and delta_text.rstrip()[-1] == '}':
|
||||
delta_text = delta_text.rstrip()[:-1]
|
||||
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
logger.debug("Searching for diff between\n%s", cur_args_json)
|
||||
logger.debug("and\n%s", prev_args_json)
|
||||
argument_diff = extract_intermediate_diff(
|
||||
cur_args_json, prev_args_json)
|
||||
logger.debug("got argument diff %s", argument_diff)
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=argument_diff).model_dump(
|
||||
arguments=delta_text).model_dump(
|
||||
exclude_none=True))
|
||||
])
|
||||
self.streamed_args_for_tool[self.current_tool_id] \
|
||||
+= argument_diff
|
||||
+= delta_text
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
|
||||
Loading…
Reference in New Issue
Block a user