From 97356f3c7e2fcaf1f8e17300eaf0b20b35eccb9d Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Fri, 29 Mar 2024 12:27:51 -0700
Subject: [PATCH] [Bugfix] Command-R Max Model Length (#3727)

---
 vllm/config.py | 31 ++++++++++++++++++++++---------
 1 file changed, 22 insertions(+), 9 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 265cfa56..62f1d700 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -765,15 +765,20 @@ def _get_and_verify_max_len(
         "max_seq_len",
         # ChatGLM2
         "seq_length",
+        # Command-R
+        "model_max_length",
         # Others
         "max_sequence_length",
         "max_seq_length",
         "seq_len",
     ]
+    max_len_key = None
     for key in possible_keys:
-        max_len_key = getattr(hf_config, key, None)
-        if max_len_key is not None:
-            derived_max_model_len = min(derived_max_model_len, max_len_key)
+        max_len = getattr(hf_config, key, None)
+        if max_len is not None:
+            max_len_key = key if max_len < derived_max_model_len \
+                else max_len_key
+            derived_max_model_len = min(derived_max_model_len, max_len)
     if derived_max_model_len == float("inf"):
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
@@ -799,10 +804,18 @@
     if max_model_len is None:
         max_model_len = derived_max_model_len
     elif max_model_len > derived_max_model_len:
-        raise ValueError(
-            f"User-specified max_model_len ({max_model_len}) is greater than "
-            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
-            " in model's config.json). This may lead to incorrect model "
-            "outputs or CUDA errors. Make sure the value is correct and "
-            "within the model context size.")
+        # Some models might have a separate key for specifying model_max_length
+        # that will be bigger than derived_max_model_len. We compare user input
+        # with model_max_length and allow this override when it's smaller.
+        model_max_length = getattr(hf_config, "model_max_length", None)
+        if model_max_length is not None and max_model_len <= model_max_length:
+            pass
+        else:
+            raise ValueError(
+                f"User-specified max_model_len ({max_model_len}) is greater "
+                "than the derived max_model_len "
+                f"({max_len_key}={derived_max_model_len} or model_max_length="
+                f"{model_max_length} in model's config.json). This may lead "
+                "to incorrect model outputs or CUDA errors. Make sure the "
+                "value is correct and within the model context size.")
     return int(max_model_len)
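
A minimal sketch of the check this patch adds (not part of the diff; the
standalone verify_max_len helper, the SimpleNamespace config, and the numeric
values below are hypothetical, reduced from vllm/config.py purely to
illustrate the new override path):

    from types import SimpleNamespace

    def verify_max_len(hf_config, max_model_len, derived_max_model_len):
        # With this patch, a user-specified max_model_len above the derived
        # limit is tolerated as long as it stays within the model's own
        # model_max_length, when config.json provides one.
        if max_model_len > derived_max_model_len:
            model_max_length = getattr(hf_config, "model_max_length", None)
            if model_max_length is None or max_model_len > model_max_length:
                raise ValueError(
                    f"User-specified max_model_len ({max_model_len}) is "
                    f"greater than the derived max_model_len "
                    f"({derived_max_model_len}).")
        return int(max_model_len)

    # Command-R-style config: model_max_length is far larger than the value
    # derived from the other length keys (8192 here is made up for the demo).
    cfg = SimpleNamespace(model_max_length=131072)
    print(verify_max_len(cfg, 16384, 8192))  # allowed: 16384 <= 131072
    # verify_max_len(cfg, 262144, 8192)      # would raise: 262144 > 131072

Before the patch, any max_model_len above the derived value raised; afterward
the override is permitted whenever it stays within model_max_length.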