From 7e65477e5e737927c2f07c913ede0763134504a3 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Fri, 3 May 2024 13:32:21 -0400
Subject: [PATCH] [Bugfix] Allow "None" or "" to be passed to CLI for string
 args that default to None (#4586)

---
 vllm/engine/arg_utils.py            | 32 +++++++++++++++++------------
 vllm/entrypoints/openai/cli_args.py | 27 +++++++++++++-----------
 2 files changed, 34 insertions(+), 25 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 1c8e1079..78cd0757 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -11,6 +11,12 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import str_to_int_tuple
 
 
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
+
+
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
@@ -96,7 +102,7 @@ class EngineArgs:
             help='Name or path of the huggingface model to use.')
         parser.add_argument(
             '--tokenizer',
-            type=str,
+            type=nullable_str,
             default=EngineArgs.tokenizer,
             help='Name or path of the huggingface tokenizer to use.')
         parser.add_argument(
@@ -105,21 +111,21 @@ class EngineArgs:
             help='Skip initialization of tokenizer and detokenizer')
         parser.add_argument(
             '--revision',
-            type=str,
+            type=nullable_str,
             default=None,
             help='The specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
         parser.add_argument(
             '--code-revision',
-            type=str,
+            type=nullable_str,
             default=None,
             help='The specific revision to use for the model code on '
             'Hugging Face Hub. It can be a branch name, a tag name, or a '
             'commit id. If unspecified, will use the default version.')
         parser.add_argument(
             '--tokenizer-revision',
-            type=str,
+            type=nullable_str,
             default=None,
             help='The specific tokenizer version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
@@ -136,7 +142,7 @@ class EngineArgs:
                             action='store_true',
                             help='Trust remote code from huggingface.')
         parser.add_argument('--download-dir',
-                            type=str,
+                            type=nullable_str,
                             default=EngineArgs.download_dir,
                             help='Directory to download and load the weights, '
                             'default to the default cache dir of '
@@ -187,7 +193,7 @@ class EngineArgs:
             'supported for common inference criteria.')
         parser.add_argument(
             '--quantization-param-path',
-            type=str,
+            type=nullable_str,
             default=None,
             help='Path to the JSON file containing the KV cache '
             'scaling factors. This should generally be supplied, when '
@@ -304,7 +310,7 @@ class EngineArgs:
         # Quantization settings.
         parser.add_argument('--quantization',
                             '-q',
-                            type=str,
+                            type=nullable_str,
                             choices=[*QUANTIZATION_METHODS, None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
@@ -349,7 +355,7 @@ class EngineArgs:
                             'asynchronous tokenization. Ignored '
                             'if tokenizer_pool_size is 0.')
         parser.add_argument('--tokenizer-pool-extra-config',
-                            type=str,
+                            type=nullable_str,
                             default=EngineArgs.tokenizer_pool_extra_config,
                             help='Extra config for tokenizer pool. '
                             'This should be a JSON string that will be '
@@ -404,7 +410,7 @@ class EngineArgs:
         # Related to Vision-language models such as llava
         parser.add_argument(
             '--image-input-type',
-            type=str,
+            type=nullable_str,
             default=None,
             choices=[
                 t.name.lower() for t in VisionLanguageConfig.ImageInputType
@@ -417,7 +423,7 @@ class EngineArgs:
             help=('Input id for image token.'))
         parser.add_argument(
             '--image-input-shape',
-            type=str,
+            type=nullable_str,
             default=None,
             help=('The biggest image input shape (worst for memory footprint) '
                   'given an input type. Only used for vLLM\'s profile_run.'))
@@ -440,7 +446,7 @@ class EngineArgs:
 
         parser.add_argument(
             '--speculative-model',
-            type=str,
+            type=nullable_str,
             default=EngineArgs.speculative_model,
             help=
             'The name of the draft model to be used in speculative decoding.')
@@ -454,7 +460,7 @@ class EngineArgs:
 
         parser.add_argument(
             '--speculative-max-model-len',
-            type=str,
+            type=int,
             default=EngineArgs.speculative_max_model_len,
             help='The maximum sequence length supported by the '
             'draft model. Sequences over this length will skip '
@@ -475,7 +481,7 @@ class EngineArgs:
             'decoding.')
 
         parser.add_argument('--model-loader-extra-config',
-                            type=str,
+                            type=nullable_str,
                             default=EngineArgs.model_loader_extra_config,
                             help='Extra config for model loader. '
                             'This will be passed to the model loader '
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 16c5b6c0..2b57ab26 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -8,7 +8,7 @@ import argparse
 import json
 import ssl
 
-from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.openai.serving_engine import LoRAModulePath
 
 
@@ -25,7 +25,10 @@ class LoRAParserAction(argparse.Action):
 def make_arg_parser():
     parser = argparse.ArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
-    parser.add_argument("--host", type=str, default=None, help="host name")
+    parser.add_argument("--host",
+                        type=nullable_str,
+                        default=None,
+                        help="host name")
     parser.add_argument("--port", type=int, default=8000, help="port number")
     parser.add_argument(
         "--uvicorn-log-level",
@@ -49,13 +52,13 @@ def make_arg_parser():
                         default=["*"],
                         help="allowed headers")
     parser.add_argument("--api-key",
-                        type=str,
+                        type=nullable_str,
                         default=None,
                         help="If provided, the server will require this key "
                         "to be presented in the header.")
     parser.add_argument("--served-model-name",
                         nargs="+",
-                        type=str,
+                        type=nullable_str,
                         default=None,
                         help="The model name(s) used in the API. If multiple "
                         "names are provided, the server will respond to any "
@@ -65,33 +68,33 @@ def make_arg_parser():
                         "same as the `--model` argument.")
     parser.add_argument(
         "--lora-modules",
-        type=str,
+        type=nullable_str,
         default=None,
         nargs='+',
         action=LoRAParserAction,
         help="LoRA module configurations in the format name=path. "
         "Multiple modules can be specified.")
     parser.add_argument("--chat-template",
-                        type=str,
+                        type=nullable_str,
                         default=None,
                         help="The file path to the chat template, "
                         "or the template in single-line form "
                         "for the specified model")
     parser.add_argument("--response-role",
-                        type=str,
+                        type=nullable_str,
                         default="assistant",
                         help="The role name to return if "
                         "`request.add_generation_prompt=true`.")
     parser.add_argument("--ssl-keyfile",
-                        type=str,
+                        type=nullable_str,
                         default=None,
                         help="The file path to the SSL key file")
     parser.add_argument("--ssl-certfile",
-                        type=str,
+                        type=nullable_str,
                         default=None,
                         help="The file path to the SSL cert file")
     parser.add_argument("--ssl-ca-certs",
-                        type=str,
+                        type=nullable_str,
                         default=None,
                         help="The CA certificates file")
     parser.add_argument(
@@ -102,12 +105,12 @@ def make_arg_parser():
     )
     parser.add_argument(
         "--root-path",
-        type=str,
+        type=nullable_str,
         default=None,
         help="FastAPI root_path when app is behind a path based routing proxy")
     parser.add_argument(
         "--middleware",
-        type=str,
+        type=nullable_str,
         action="append",
         default=[],
         help="Additional ASGI middleware to apply to the app. "