Signed-off-by: xiyuan lee <lixiyuan@haier.com>
This commit is contained in:
parent
b4614656b8
commit
f028dff33d
@ -12,8 +12,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|||||||
FunctionCall, ToolCall)
|
FunctionCall, ToolCall)
|
||||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||||
ToolParser, ToolParserManager)
|
ToolParser, ToolParserManager)
|
||||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
|
||||||
extract_intermediate_diff)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||||
from vllm.utils import random_uuid
|
from vllm.utils import random_uuid
|
||||||
@ -190,8 +188,11 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||||
"arguments")
|
"arguments")
|
||||||
if diff:
|
if diff:
|
||||||
diff = json.dumps(diff).replace(
|
diff = diff.encode('utf-8').decode(
|
||||||
self.streamed_args_for_tool[self.current_tool_id], "")
|
'unicode_escape') if diff is str else diff
|
||||||
|
diff = json.dumps(
|
||||||
|
diff, ensure_ascii=False
|
||||||
|
)[len(self.streamed_args_for_tool[self.current_tool_id]):]
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Finishing tool and found diff that had not "
|
"Finishing tool and found diff that had not "
|
||||||
"been streamed yet: %s", diff)
|
"been streamed yet: %s", diff)
|
||||||
@ -307,22 +308,20 @@ class Hermes2ProToolParser(ToolParser):
|
|||||||
|
|
||||||
# last case -- we have an update to existing arguments.
|
# last case -- we have an update to existing arguments.
|
||||||
elif cur_arguments and prev_arguments:
|
elif cur_arguments and prev_arguments:
|
||||||
|
if isinstance(delta_text, str) and len(delta_text.rstrip(
|
||||||
|
)) >= 1 and delta_text.rstrip()[-1] == '}':
|
||||||
|
delta_text = delta_text.rstrip()[:-1]
|
||||||
|
|
||||||
|
logger.debug("got diff %s", delta_text)
|
||||||
|
|
||||||
cur_args_json = json.dumps(cur_arguments)
|
|
||||||
prev_args_json = json.dumps(prev_arguments)
|
|
||||||
logger.debug("Searching for diff between\n%s", cur_args_json)
|
|
||||||
logger.debug("and\n%s", prev_args_json)
|
|
||||||
argument_diff = extract_intermediate_diff(
|
|
||||||
cur_args_json, prev_args_json)
|
|
||||||
logger.debug("got argument diff %s", argument_diff)
|
|
||||||
delta = DeltaMessage(tool_calls=[
|
delta = DeltaMessage(tool_calls=[
|
||||||
DeltaToolCall(index=self.current_tool_id,
|
DeltaToolCall(index=self.current_tool_id,
|
||||||
function=DeltaFunctionCall(
|
function=DeltaFunctionCall(
|
||||||
arguments=argument_diff).model_dump(
|
arguments=delta_text).model_dump(
|
||||||
exclude_none=True))
|
exclude_none=True))
|
||||||
])
|
])
|
||||||
self.streamed_args_for_tool[self.current_tool_id] \
|
self.streamed_args_for_tool[self.current_tool_id] \
|
||||||
+= argument_diff
|
+= delta_text
|
||||||
|
|
||||||
# handle saving the state for the current tool into
|
# handle saving the state for the current tool into
|
||||||
# the "prev" list for use in diffing for the next iteration
|
# the "prev" list for use in diffing for the next iteration
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user