From 7560ae5cafbae3af9967ac7dc979cb31a40fc572 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 21 Nov 2024 12:30:42 -0800 Subject: [PATCH] [8/N] enable cli flag without a space (#10529) Signed-off-by: youkaichao --- tests/compile/test_basic_correctness.py | 4 ++-- tests/engine/test_arg_utils.py | 28 +++++++++++++++++++++++++ tests/tpu/test_custom_dispatcher.py | 9 ++++---- vllm/engine/arg_utils.py | 5 ++++- vllm/utils.py | 4 ++++ 5 files changed, 43 insertions(+), 7 deletions(-) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index c0db2e78..b7170886 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -103,7 +103,7 @@ def test_compile_correctness(test_setting: TestSetting): CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE, ]: - all_args.append(final_args + ["-O", str(level)]) + all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) # inductor will change the output, so we only compare if the output @@ -121,7 +121,7 @@ def test_compile_correctness(test_setting: TestSetting): CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE, ]: - all_args.append(final_args + ["-O", str(level)]) + all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: # "DYNAMO_ONCE" will always use fullgraph diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 7b1be5a9..5b0e76fe 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -31,6 +31,34 @@ def test_limit_mm_per_prompt_parser(arg, expected): assert args.limit_mm_per_prompt == expected +def test_compilation_config(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + + # default value + args = parser.parse_args([]) + assert args.compilation_config is None + + # set to O3 + args = parser.parse_args(["-O3"]) + assert args.compilation_config.level == 3 + + # set to O 3 (space) 
+ args = parser.parse_args(["-O", "3"]) + assert args.compilation_config.level == 3 + + # set to O=3 (equals) + args = parser.parse_args(["-O=3"]) + assert args.compilation_config.level == 3 + + # set to json + args = parser.parse_args(["--compilation-config", '{"level": 3}']) + assert args.compilation_config.level == 3 + + # set to json (equals) + args = parser.parse_args(['--compilation-config={"level": 3}']) + assert args.compilation_config.level == 3 + + def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([ diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index df348258..bb1379de 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -13,9 +13,10 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000" def test_custom_dispatcher(): compare_two_settings( "google/gemma-2b", - arg1=["--enforce-eager", "-O", - str(CompilationLevel.DYNAMO_ONCE)], - arg2=["--enforce-eager", "-O", - str(CompilationLevel.DYNAMO_AS_IS)], + arg1=[ + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_ONCE}", + ], + arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"], env1={}, env2={}) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9288cd22..88862a18 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -882,7 +882,10 @@ class EngineArgs: 'testing only. level 3 is the recommended level ' 'for production.\n' 'To specify the full compilation config, ' - 'use a JSON string.') + 'use a JSON string.\n' + 'Following the convention of traditional ' + 'compilers, using -O without space is also ' + 'supported. 
-O3 is equivalent to -O 3.') return parser diff --git a/vllm/utils.py b/vllm/utils.py index 424e7d09..67b2629e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1192,6 +1192,10 @@ class FlexibleArgumentParser(argparse.ArgumentParser): else: processed_args.append('--' + arg[len('--'):].replace('_', '-')) + elif arg.startswith('-O') and len(arg) > 2: + # allow -O flag to be used without space, e.g. -O3 (also -O=3) + processed_args.append('-O') + processed_args.append(arg[3:] if arg[2] == '=' else arg[2:]) else: processed_args.append(arg)