[Bugfix] Allow "None" or "" to be passed to CLI for string args that default to None (#4586)

This commit is contained in:
Michael Goin 2024-05-03 13:32:21 -04:00 committed by GitHub
parent 3521ba4f25
commit 7e65477e5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 25 deletions

View File

@ -11,6 +11,12 @@ from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import str_to_int_tuple from vllm.utils import str_to_int_tuple
def nullable_str(val: str):
    """Argparse ``type`` helper: map "" or the literal string "None" to None.

    Any other value is passed through unchanged, so existing string
    arguments keep their behavior while callers gain an explicit way to
    reset a defaulted-to-None option from the command line.
    """
    # Exact match on "None" only — "none"/"NONE" remain ordinary strings.
    return val if val and val != "None" else None
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
@ -96,7 +102,7 @@ class EngineArgs:
help='Name or path of the huggingface model to use.') help='Name or path of the huggingface model to use.')
parser.add_argument( parser.add_argument(
'--tokenizer', '--tokenizer',
type=str, type=nullable_str,
default=EngineArgs.tokenizer, default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use.') help='Name or path of the huggingface tokenizer to use.')
parser.add_argument( parser.add_argument(
@ -105,21 +111,21 @@ class EngineArgs:
help='Skip initialization of tokenizer and detokenizer') help='Skip initialization of tokenizer and detokenizer')
parser.add_argument( parser.add_argument(
'--revision', '--revision',
type=str, type=nullable_str,
default=None, default=None,
help='The specific model version to use. It can be a branch ' help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
'the default version.') 'the default version.')
parser.add_argument( parser.add_argument(
'--code-revision', '--code-revision',
type=str, type=nullable_str,
default=None, default=None,
help='The specific revision to use for the model code on ' help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a ' 'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.') 'commit id. If unspecified, will use the default version.')
parser.add_argument( parser.add_argument(
'--tokenizer-revision', '--tokenizer-revision',
type=str, type=nullable_str,
default=None, default=None,
help='The specific tokenizer version to use. It can be a branch ' help='The specific tokenizer version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use ' 'name, a tag name, or a commit id. If unspecified, will use '
@ -136,7 +142,7 @@ class EngineArgs:
action='store_true', action='store_true',
help='Trust remote code from huggingface.') help='Trust remote code from huggingface.')
parser.add_argument('--download-dir', parser.add_argument('--download-dir',
type=str, type=nullable_str,
default=EngineArgs.download_dir, default=EngineArgs.download_dir,
help='Directory to download and load the weights, ' help='Directory to download and load the weights, '
'default to the default cache dir of ' 'default to the default cache dir of '
@ -187,7 +193,7 @@ class EngineArgs:
'supported for common inference criteria.') 'supported for common inference criteria.')
parser.add_argument( parser.add_argument(
'--quantization-param-path', '--quantization-param-path',
type=str, type=nullable_str,
default=None, default=None,
help='Path to the JSON file containing the KV cache ' help='Path to the JSON file containing the KV cache '
'scaling factors. This should generally be supplied, when ' 'scaling factors. This should generally be supplied, when '
@ -304,7 +310,7 @@ class EngineArgs:
# Quantization settings. # Quantization settings.
parser.add_argument('--quantization', parser.add_argument('--quantization',
'-q', '-q',
type=str, type=nullable_str,
choices=[*QUANTIZATION_METHODS, None], choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization, default=EngineArgs.quantization,
help='Method used to quantize the weights. If ' help='Method used to quantize the weights. If '
@ -349,7 +355,7 @@ class EngineArgs:
'asynchronous tokenization. Ignored ' 'asynchronous tokenization. Ignored '
'if tokenizer_pool_size is 0.') 'if tokenizer_pool_size is 0.')
parser.add_argument('--tokenizer-pool-extra-config', parser.add_argument('--tokenizer-pool-extra-config',
type=str, type=nullable_str,
default=EngineArgs.tokenizer_pool_extra_config, default=EngineArgs.tokenizer_pool_extra_config,
help='Extra config for tokenizer pool. ' help='Extra config for tokenizer pool. '
'This should be a JSON string that will be ' 'This should be a JSON string that will be '
@ -404,7 +410,7 @@ class EngineArgs:
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
parser.add_argument( parser.add_argument(
'--image-input-type', '--image-input-type',
type=str, type=nullable_str,
default=None, default=None,
choices=[ choices=[
t.name.lower() for t in VisionLanguageConfig.ImageInputType t.name.lower() for t in VisionLanguageConfig.ImageInputType
@ -417,7 +423,7 @@ class EngineArgs:
help=('Input id for image token.')) help=('Input id for image token.'))
parser.add_argument( parser.add_argument(
'--image-input-shape', '--image-input-shape',
type=str, type=nullable_str,
default=None, default=None,
help=('The biggest image input shape (worst for memory footprint) ' help=('The biggest image input shape (worst for memory footprint) '
'given an input type. Only used for vLLM\'s profile_run.')) 'given an input type. Only used for vLLM\'s profile_run.'))
@ -440,7 +446,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--speculative-model', '--speculative-model',
type=str, type=nullable_str,
default=EngineArgs.speculative_model, default=EngineArgs.speculative_model,
help= help=
'The name of the draft model to be used in speculative decoding.') 'The name of the draft model to be used in speculative decoding.')
@ -454,7 +460,7 @@ class EngineArgs:
parser.add_argument( parser.add_argument(
'--speculative-max-model-len', '--speculative-max-model-len',
type=str, type=int,
default=EngineArgs.speculative_max_model_len, default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the ' help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip ' 'draft model. Sequences over this length will skip '
@ -475,7 +481,7 @@ class EngineArgs:
'decoding.') 'decoding.')
parser.add_argument('--model-loader-extra-config', parser.add_argument('--model-loader-extra-config',
type=str, type=nullable_str,
default=EngineArgs.model_loader_extra_config, default=EngineArgs.model_loader_extra_config,
help='Extra config for model loader. ' help='Extra config for model loader. '
'This will be passed to the model loader ' 'This will be passed to the model loader '

View File

@ -8,7 +8,7 @@ import argparse
import json import json
import ssl import ssl
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.openai.serving_engine import LoRAModulePath from vllm.entrypoints.openai.serving_engine import LoRAModulePath
@ -25,7 +25,10 @@ class LoRAParserAction(argparse.Action):
def make_arg_parser(): def make_arg_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.") description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name") parser.add_argument("--host",
type=nullable_str,
default=None,
help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number") parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument( parser.add_argument(
"--uvicorn-log-level", "--uvicorn-log-level",
@ -49,13 +52,13 @@ def make_arg_parser():
default=["*"], default=["*"],
help="allowed headers") help="allowed headers")
parser.add_argument("--api-key", parser.add_argument("--api-key",
type=str, type=nullable_str,
default=None, default=None,
help="If provided, the server will require this key " help="If provided, the server will require this key "
"to be presented in the header.") "to be presented in the header.")
parser.add_argument("--served-model-name", parser.add_argument("--served-model-name",
nargs="+", nargs="+",
type=str, type=nullable_str,
default=None, default=None,
help="The model name(s) used in the API. If multiple " help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any " "names are provided, the server will respond to any "
@ -65,33 +68,33 @@ def make_arg_parser():
"same as the `--model` argument.") "same as the `--model` argument.")
parser.add_argument( parser.add_argument(
"--lora-modules", "--lora-modules",
type=str, type=nullable_str,
default=None, default=None,
nargs='+', nargs='+',
action=LoRAParserAction, action=LoRAParserAction,
help="LoRA module configurations in the format name=path. " help="LoRA module configurations in the format name=path. "
"Multiple modules can be specified.") "Multiple modules can be specified.")
parser.add_argument("--chat-template", parser.add_argument("--chat-template",
type=str, type=nullable_str,
default=None, default=None,
help="The file path to the chat template, " help="The file path to the chat template, "
"or the template in single-line form " "or the template in single-line form "
"for the specified model") "for the specified model")
parser.add_argument("--response-role", parser.add_argument("--response-role",
type=str, type=nullable_str,
default="assistant", default="assistant",
help="The role name to return if " help="The role name to return if "
"`request.add_generation_prompt=true`.") "`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile", parser.add_argument("--ssl-keyfile",
type=str, type=nullable_str,
default=None, default=None,
help="The file path to the SSL key file") help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile", parser.add_argument("--ssl-certfile",
type=str, type=nullable_str,
default=None, default=None,
help="The file path to the SSL cert file") help="The file path to the SSL cert file")
parser.add_argument("--ssl-ca-certs", parser.add_argument("--ssl-ca-certs",
type=str, type=nullable_str,
default=None, default=None,
help="The CA certificates file") help="The CA certificates file")
parser.add_argument( parser.add_argument(
@ -102,12 +105,12 @@ def make_arg_parser():
) )
parser.add_argument( parser.add_argument(
"--root-path", "--root-path",
type=str, type=nullable_str,
default=None, default=None,
help="FastAPI root_path when app is behind a path based routing proxy") help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument( parser.add_argument(
"--middleware", "--middleware",
type=str, type=nullable_str,
action="append", action="append",
default=[], default=[],
help="Additional ASGI middleware to apply to the app. " help="Additional ASGI middleware to apply to the app. "