Fix missing docs and out of sync EngineArgs (#4219)

Co-authored-by: Harry Mellor <hmellor@oxts.com>
Harry Mellor 2024-04-20 04:51:33 +01:00 committed by GitHub
parent 138485a82d
commit 682789d402
3 changed files with 98 additions and 197 deletions

docs/source/conf.py

@@ -11,12 +11,14 @@
 # documentation root, use os.path.abspath to make it absolute, like shown here.
 import logging
+import os
 import sys
 from typing import List
 from sphinx.ext import autodoc
 logger = logging.getLogger(__name__)
+sys.path.append(os.path.abspath("../.."))
 # -- Project information -----------------------------------------------------
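The `sys.path` addition is what lets the docs build import the `vllm` package so that the new `.. argparse::` directive can introspect `vllm.engine.arg_utils`. A minimal sketch of the relevant `conf.py` pieces, assuming sphinx-argparse is registered under its usual module name `sphinxarg.ext` (the `extensions` list is not part of this diff):

```python
# docs/source/conf.py (sketch of the relevant pieces, not the full file)
import os
import sys

# Make the repository root importable so the argparse/autodoc directives
# can resolve `vllm.engine.arg_utils` at docs build time.
sys.path.append(os.path.abspath("../.."))

# Assumption: sphinx-argparse supplies the `.. argparse::` directive used
# in docs/source/models/engine_args.rst.
extensions = [
    "sphinxarg.ext",
]
```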

docs/source/models/engine_args.rst

@@ -5,133 +5,17 @@ Engine Arguments
 Below, you can find an explanation of every engine argument for vLLM:
-.. option:: --model <model_name_or_path>
-    Name or path of the huggingface model to use.
-.. option:: --tokenizer <tokenizer_name_or_path>
-    Name or path of the huggingface tokenizer to use.
-.. option:: --revision <revision>
-    The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
-.. option:: --tokenizer-revision <revision>
-    The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version.
-.. option:: --tokenizer-mode {auto,slow}
-    The tokenizer mode.
-    * "auto" will use the fast tokenizer if available.
-    * "slow" will always use the slow tokenizer.
-.. option:: --trust-remote-code
-    Trust remote code from huggingface.
-.. option:: --download-dir <directory>
-    Directory to download and load the weights, default to the default cache dir of huggingface.
-.. option:: --load-format {auto,pt,safetensors,npcache,dummy,tensorizer}
-    The format of the model weights to load.
-    * "auto" will try to load the weights in the safetensors format and fall back to the pytorch bin format if safetensors format is not available.
-    * "pt" will load the weights in the pytorch bin format.
-    * "safetensors" will load the weights in the safetensors format.
-    * "npcache" will load the weights in pytorch format and store a numpy cache to speed up the loading.
-    * "dummy" will initialize the weights with random values, mainly for profiling.
-    * "tensorizer" will load serialized weights using `CoreWeave's Tensorizer model deserializer. <https://github.com/coreweave/tensorizer>`_ See `examples/tensorize_vllm_model.py <https://github.com/vllm-project/vllm/blob/main/examples/tensorize_vllm_model.py>`_ to serialize a vLLM model, and for more information.
-.. option:: --dtype {auto,half,float16,bfloat16,float,float32}
-    Data type for model weights and activations.
-    * "auto" will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models.
-    * "half" for FP16. Recommended for AWQ quantization.
-    * "float16" is the same as "half".
-    * "bfloat16" for a balance between precision and range.
-    * "float" is shorthand for FP32 precision.
-    * "float32" for FP32 precision.
-.. option:: --max-model-len <length>
-    Model context length. If unspecified, will be automatically derived from the model config.
-.. option:: --worker-use-ray
-    Use Ray for distributed serving, will be automatically set when using more than 1 GPU.
-.. option:: --pipeline-parallel-size (-pp) <size>
-    Number of pipeline stages.
-.. option:: --tensor-parallel-size (-tp) <size>
-    Number of tensor parallel replicas.
-.. option:: --max-parallel-loading-workers <workers>
-    Load model sequentially in multiple batches, to avoid RAM OOM when using tensor parallel and large models.
-.. option:: --block-size {8,16,32}
-    Token block size for contiguous chunks of tokens.
-.. option:: --enable-prefix-caching
-    Enables automatic prefix caching
-.. option:: --seed <seed>
-    Random seed for operations.
-.. option:: --swap-space <size>
-    CPU swap space size (GiB) per GPU.
-.. option:: --gpu-memory-utilization <fraction>
-    The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.
-    For example, a value of 0.5 would imply 50% GPU memory utilization.
-    If unspecified, will use the default value of 0.9.
-.. option:: --max-num-batched-tokens <tokens>
-    Maximum number of batched tokens per iteration.
-.. option:: --max-num-seqs <sequences>
-    Maximum number of sequences per iteration.
-.. option:: --max-paddings <paddings>
-    Maximum number of paddings in a batch.
-.. option:: --disable-log-stats
-    Disable logging statistics.
-.. option:: --quantization (-q) {awq,squeezellm,None}
-    Method used to quantize the weights.
+.. argparse::
+    :module: vllm.engine.arg_utils
+    :func: _engine_args_parser
+    :prog: -m vllm.entrypoints.openai.api_server
 Async Engine Arguments
 ----------------------
 Below are the additional arguments related to the asynchronous engine:
-.. option:: --engine-use-ray
-    Use Ray to start the LLM engine in a separate process as the server process.
-.. option:: --disable-log-requests
-    Disable logging requests.
-.. option:: --max-log-len
-    Max number of prompt characters or prompt ID numbers being printed in log. Defaults to unlimited.
+.. argparse::
+    :module: vllm.engine.arg_utils
+    :func: _async_engine_args_parser
+    :prog: -m vllm.entrypoints.openai.api_server
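The `:func:` option of the directive expects a zero-argument callable that returns a fully populated `argparse.ArgumentParser`; the two helpers it names are added at the bottom of `vllm/engine/arg_utils.py` in the last hunk below. A simplified sketch of the contract the directive relies on (the real introspection happens inside sphinx-argparse):

```python
import importlib

# Roughly what `.. argparse::` does at docs build time (simplified sketch).
module = importlib.import_module("vllm.engine.arg_utils")  # :module:
build_parser = getattr(module, "_engine_args_parser")      # :func:
parser = build_parser()  # must return an argparse.ArgumentParser
print(type(parser).__name__)  # -> "ArgumentParser"
```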

vllm/engine/arg_utils.py

@@ -82,57 +82,55 @@ class EngineArgs:
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         """Shared CLI arguments for vLLM engine."""
-        # NOTE: If you update any of the arguments below, please also
-        # make sure to update docs/source/models/engine_args.rst
         # Model arguments
         parser.add_argument(
             '--model',
             type=str,
             default='facebook/opt-125m',
-            help='name or path of the huggingface model to use')
+            help='Name or path of the huggingface model to use.')
         parser.add_argument(
             '--tokenizer',
             type=str,
             default=EngineArgs.tokenizer,
-            help='name or path of the huggingface tokenizer to use')
+            help='Name or path of the huggingface tokenizer to use.')
         parser.add_argument(
             '--revision',
             type=str,
             default=None,
-            help='the specific model version to use. It can be a branch '
+            help='The specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
         parser.add_argument(
             '--code-revision',
             type=str,
             default=None,
-            help='the specific revision to use for the model code on '
+            help='The specific revision to use for the model code on '
             'Hugging Face Hub. It can be a branch name, a tag name, or a '
             'commit id. If unspecified, will use the default version.')
         parser.add_argument(
             '--tokenizer-revision',
             type=str,
             default=None,
-            help='the specific tokenizer version to use. It can be a branch '
+            help='The specific tokenizer version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
-        parser.add_argument('--tokenizer-mode',
-                            type=str,
-                            default=EngineArgs.tokenizer_mode,
-                            choices=['auto', 'slow'],
-                            help='tokenizer mode. "auto" will use the fast '
-                            'tokenizer if available, and "slow" will '
-                            'always use the slow tokenizer.')
+        parser.add_argument(
+            '--tokenizer-mode',
+            type=str,
+            default=EngineArgs.tokenizer_mode,
+            choices=['auto', 'slow'],
+            help='The tokenizer mode.\n\n* "auto" will use the '
+            'fast tokenizer if available.\n* "slow" will '
+            'always use the slow tokenizer.')
         parser.add_argument('--trust-remote-code',
                             action='store_true',
-                            help='trust remote code from huggingface')
+                            help='Trust remote code from huggingface.')
         parser.add_argument('--download-dir',
                             type=str,
                             default=EngineArgs.download_dir,
-                            help='directory to download and load the weights, '
+                            help='Directory to download and load the weights, '
                             'default to the default cache dir of '
-                            'huggingface')
+                            'huggingface.')
         parser.add_argument(
             '--load-format',
             type=str,
@@ -140,19 +138,19 @@ class EngineArgs:
             choices=[
                 'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
             ],
-            help='The format of the model weights to load. '
-            '"auto" will try to load the weights in the safetensors format '
+            help='The format of the model weights to load.\n\n'
+            '* "auto" will try to load the weights in the safetensors format '
             'and fall back to the pytorch bin format if safetensors format '
-            'is not available. '
-            '"pt" will load the weights in the pytorch bin format. '
-            '"safetensors" will load the weights in the safetensors format. '
-            '"npcache" will load the weights in pytorch format and store '
-            'a numpy cache to speed up the loading. '
-            '"dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.'
-            '"tensorizer" will load the weights using tensorizer from CoreWeave'
-            'which assumes tensorizer_uri is set to the location of the '
-            'serialized weights.')
+            'is not available.\n'
+            '* "pt" will load the weights in the pytorch bin format.\n'
+            '* "safetensors" will load the weights in the safetensors format.\n'
+            '* "npcache" will load the weights in pytorch format and store '
+            'a numpy cache to speed up the loading.\n'
+            '* "dummy" will initialize the weights with random values, '
+            'which is mainly for profiling.\n'
+            '* "tensorizer" will load the weights using tensorizer from '
+            'CoreWeave which assumes tensorizer_uri is set to the location of '
+            'the serialized weights.')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -160,10 +158,14 @@ class EngineArgs:
             choices=[
                 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
             ],
-            help='data type for model weights and activations. '
-            'The "auto" option will use FP16 precision '
-            'for FP32 and FP16 models, and BF16 precision '
-            'for BF16 models.')
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.')
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
@@ -172,7 +174,7 @@ class EngineArgs:
             help='Data type for kv cache storage. If "auto", will use model '
             'data type. FP8_E5M2 (without scaling) is only supported on cuda '
             'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
-            'supported for common inference criteria. ')
+            'supported for common inference criteria.')
         parser.add_argument(
             '--quantization-param-path',
             type=str,
@@ -183,58 +185,59 @@ class EngineArgs:
             'default to 1.0, which may cause accuracy issues. '
             'FP8_E5M2 (without scaling) is only supported on cuda version'
             'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
-            'supported for common inference criteria. ')
+            'supported for common inference criteria.')
         parser.add_argument('--max-model-len',
                             type=int,
                             default=EngineArgs.max_model_len,
-                            help='model context length. If unspecified, '
-                            'will be automatically derived from the model.')
+                            help='Model context length. If unspecified, will '
+                            'be automatically derived from the model config.')
         parser.add_argument(
             '--guided-decoding-backend',
             type=str,
             default='outlines',
             choices=['outlines', 'lm-format-enforcer'],
             help='Which engine will be used for guided decoding'
-            ' (JSON schema / regex etc)')
+            ' (JSON schema / regex etc).')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
-                            help='use Ray for distributed serving, will be '
-                            'automatically set when using more than 1 GPU')
+                            help='Use Ray for distributed serving, will be '
+                            'automatically set when using more than 1 GPU.')
         parser.add_argument('--pipeline-parallel-size',
                             '-pp',
                             type=int,
                             default=EngineArgs.pipeline_parallel_size,
-                            help='number of pipeline stages')
+                            help='Number of pipeline stages.')
         parser.add_argument('--tensor-parallel-size',
                             '-tp',
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
-                            help='number of tensor parallel replicas')
+                            help='Number of tensor parallel replicas.')
         parser.add_argument(
             '--max-parallel-loading-workers',
             type=int,
             default=EngineArgs.max_parallel_loading_workers,
-            help='load model sequentially in multiple batches, '
+            help='Load model sequentially in multiple batches, '
             'to avoid RAM OOM when using tensor '
-            'parallel and large models')
+            'parallel and large models.')
         parser.add_argument(
             '--ray-workers-use-nsight',
             action='store_true',
-            help='If specified, use nsight to profile ray workers')
+            help='If specified, use nsight to profile Ray workers.')
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
                             choices=[8, 16, 32, 128],
-                            help='token block size')
+                            help='Token block size for contiguous chunks of '
+                            'tokens.')
         parser.add_argument('--enable-prefix-caching',
                             action='store_true',
-                            help='Enables automatic prefix caching')
+                            help='Enables automatic prefix caching.')
         parser.add_argument('--use-v2-block-manager',
                             action='store_true',
-                            help='Use BlockSpaceMangerV2')
+                            help='Use BlockSpaceMangerV2.')
         parser.add_argument(
             '--num-lookahead-slots',
             type=int,
@@ -247,18 +250,19 @@ class EngineArgs:
         parser.add_argument('--seed',
                             type=int,
                             default=EngineArgs.seed,
-                            help='random seed')
+                            help='Random seed for operations.')
         parser.add_argument('--swap-space',
                             type=int,
                             default=EngineArgs.swap_space,
-                            help='CPU swap space size (GiB) per GPU')
+                            help='CPU swap space size (GiB) per GPU.')
         parser.add_argument(
             '--gpu-memory-utilization',
             type=float,
             default=EngineArgs.gpu_memory_utilization,
-            help='the fraction of GPU memory to be used for '
-            'the model executor, which can range from 0 to 1.'
-            'If unspecified, will use the default value of 0.9.')
+            help='The fraction of GPU memory to be used for the model '
+            'executor, which can range from 0 to 1. For example, a value of '
+            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
+            'will use the default value of 0.9.')
         parser.add_argument(
             '--num-gpu-blocks-override',
             type=int,
@@ -268,21 +272,21 @@ class EngineArgs:
         parser.add_argument('--max-num-batched-tokens',
                             type=int,
                             default=EngineArgs.max_num_batched_tokens,
-                            help='maximum number of batched tokens per '
-                            'iteration')
+                            help='Maximum number of batched tokens per '
+                            'iteration.')
         parser.add_argument('--max-num-seqs',
                             type=int,
                             default=EngineArgs.max_num_seqs,
-                            help='maximum number of sequences per iteration')
+                            help='Maximum number of sequences per iteration.')
         parser.add_argument(
             '--max-logprobs',
             type=int,
             default=EngineArgs.max_logprobs,
-            help=('max number of log probs to return logprobs is specified in'
-                  ' SamplingParams'))
+            help=('Max number of log probs to return logprobs is specified in'
+                  ' SamplingParams.'))
         parser.add_argument('--disable-log-stats',
                             action='store_true',
-                            help='disable logging statistics')
+                            help='Disable logging statistics.')
         # Quantization settings.
         parser.add_argument('--quantization',
                             '-q',
@@ -303,13 +307,13 @@ class EngineArgs:
         parser.add_argument('--max-context-len-to-capture',
                             type=int,
                             default=EngineArgs.max_context_len_to_capture,
-                            help='maximum context length covered by CUDA '
+                            help='Maximum context length covered by CUDA '
                             'graphs. When a sequence has context length '
                             'larger than this, we fall back to eager mode.')
         parser.add_argument('--disable-custom-all-reduce',
                             action='store_true',
                             default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig')
+                            help='See ParallelConfig.')
         parser.add_argument('--tokenizer-pool-size',
                             type=int,
                             default=EngineArgs.tokenizer_pool_size,
@@ -402,7 +406,7 @@ class EngineArgs:
             '--enable-chunked-prefill',
             action='store_true',
             help='If set, the prefill requests can be chunked based on the '
-            'max_num_batched_tokens')
+            'max_num_batched_tokens.')
         parser.add_argument(
             '--speculative-model',
@@ -416,7 +420,7 @@ class EngineArgs:
             type=int,
             default=None,
             help='The number of speculative tokens to sample from '
-            'the draft model in speculative decoding')
+            'the draft model in speculative decoding.')
         parser.add_argument('--model-loader-extra-config',
                             type=str,
@@ -534,20 +538,31 @@ class AsyncEngineArgs(EngineArgs):
     max_log_len: Optional[int] = None
 
     @staticmethod
-    def add_cli_args(
-            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
-        parser = EngineArgs.add_cli_args(parser)
+    def add_cli_args(parser: argparse.ArgumentParser,
+                     async_args_only: bool = False) -> argparse.ArgumentParser:
+        if not async_args_only:
+            parser = EngineArgs.add_cli_args(parser)
         parser.add_argument('--engine-use-ray',
                             action='store_true',
-                            help='use Ray to start the LLM engine in a '
+                            help='Use Ray to start the LLM engine in a '
                             'separate process as the server process.')
         parser.add_argument('--disable-log-requests',
                             action='store_true',
-                            help='disable logging requests')
+                            help='Disable logging requests.')
         parser.add_argument('--max-log-len',
                             type=int,
                             default=None,
-                            help='max number of prompt characters or prompt '
-                            'ID numbers being printed in log. '
-                            'Default: unlimited.')
+                            help='Max number of prompt characters or prompt '
+                            'ID numbers being printed in log.'
+                            '\n\nDefault: Unlimited')
         return parser
+
+
+# These functions are used by sphinx to build the documentation
+def _engine_args_parser():
+    return EngineArgs.add_cli_args(argparse.ArgumentParser())
+
+
+def _async_engine_args_parser():
+    return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
+                                        async_args_only=True)
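For a quick local check of what these helpers hand to Sphinx, both parsers can be printed from a checkout where vLLM is importable; this snippet is illustrative only and not part of the commit:

```python
import argparse

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs

# Shared engine flags, as rendered under "Engine Arguments".
EngineArgs.add_cli_args(argparse.ArgumentParser()).print_help()

# Only the async additions (--engine-use-ray, --disable-log-requests,
# --max-log-len), as rendered under "Async Engine Arguments".
AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
                             async_args_only=True).print_help()
```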