From 63e39937f990818e2f22a9b821a4aa22387057a7 Mon Sep 17 00:00:00 2001 From: xendo Date: Thu, 3 Oct 2024 20:02:07 +0200 Subject: [PATCH] [Frontend] [Neuron] Parse literals out of override-neuron-config (#8959) Co-authored-by: Jerzy Zagorski --- tests/engine/test_arg_utils.py | 48 ++++++++++++++++++++++++---------- vllm/engine/arg_utils.py | 9 +++---- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 360ac1bf..f7dc167f 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -42,22 +42,42 @@ def test_bad_nullable_kvs(arg): nullable_kvs(arg) -@pytest.mark.parametrize(("arg", "expected"), [ - (None, None), - ("{}", {}), - ('{"num_crops": 4}', { - "num_crops": 4 - }), - ('{"foo": {"bar": "baz"}}', { - "foo": { - "bar": "baz" - } - }), +# yapf: disable +@pytest.mark.parametrize(("arg", "expected", "option"), [ + (None, None, "mm-processor-kwargs"), + ("{}", {}, "mm-processor-kwargs"), + ( + '{"num_crops": 4}', + { + "num_crops": 4 + }, + "mm-processor-kwargs" + ), + ( + '{"foo": {"bar": "baz"}}', + { + "foo": + { + "bar": "baz" + } + }, + "mm-processor-kwargs" + ), + ( + '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}', + { + "cast_logits_dtype": "bfloat16", + "sequence_parallel_norm": True, + "sequence_parallel_norm_threshold": 2048, + }, + "override-neuron-config" + ), ]) -def test_mm_processor_kwargs_prompt_parser(arg, expected): +# yapf: enable +def test_composite_arg_parser(arg, expected, option): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) else: - args = parser.parse_args(["--mm-processor-kwargs", arg]) - assert args.mm_processor_kwargs == expected + args = parser.parse_args([f"--{option}", arg]) + assert getattr(args, option.replace("-", "_")) == expected diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 097fe7c0..81baab3f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -800,13 +800,10 @@ class EngineArgs: "lower performance.") parser.add_argument( '--override-neuron-config', - type=lambda configs: { - str(key): value - for key, value in - (config.split(':') for config in configs.split(',')) - }, + type=json.loads, default=None, - help="override or set neuron device configuration.") + help="Override or set neuron device configuration. " + "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") parser.add_argument( '--scheduling-policy',