From a19bc5c6281cb1d539043acb06699bf8438bb254 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 27 Sep 2023 16:34:00 -0700
Subject: [PATCH] Automatically configure `max_num_batched_tokens` (#1198)

---
 vllm/config.py           | 43 +++++++++++++++++++++++++++++++---------
 vllm/engine/arg_utils.py |  3 +--
 2 files changed, 35 insertions(+), 11 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 328ba94c..5fc7696b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -266,11 +266,36 @@ class SchedulerConfig:
             and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_model_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
@@ -350,14 +375,14 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        raise ValueError(
+            "The model's config.json must contain one of the following keys "
+            "to determine the original maximum length of the model: "
+            f"{possible_keys}")
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
-        if derived_max_model_len == float("inf"):
-            raise ValueError(
-                "When using rope_scaling, the model's config.json must "
-                "contain one of the following keys to determine the original "
-                f"maximum length of the model: {possible_keys}")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         derived_max_model_len *= scaling_factor
@@ -371,4 +396,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return max_model_len
+    return int(max_model_len)
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 65a5d74f..237782df 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -25,7 +25,7 @@ class EngineArgs:
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
@@ -34,7 +34,6 @@ def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
-        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
     def add_cli_args(
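
Illustration (not part of the patch): a minimal, self-contained sketch of the new default logic introduced in SchedulerConfig.__init__, written as a standalone helper instead of importing vLLM; the helper name is hypothetical.

    from typing import Optional

    def resolve_max_num_batched_tokens(
        max_num_batched_tokens: Optional[int],
        max_model_len: int,
    ) -> int:
        # Mirrors the patched SchedulerConfig.__init__: if the user did not
        # set a value, default to the model length, but never below 2048.
        if max_num_batched_tokens is not None:
            return max_num_batched_tokens
        return max(max_model_len, 2048)

    # A 1024-token model still gets a 2048-token batch budget...
    assert resolve_max_num_batched_tokens(None, 1024) == 2048
    # ...while a 4096-token model defaults to its full context length.
    assert resolve_max_num_batched_tokens(None, 4096) == 4096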