[Frontend] Warn if user max_model_len is greater than derived max_model_len (#7080)
Signed-off-by: Jefferson Fialho <jfialho@ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>

commit 825b044863 (parent 44dcb52e39)
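With this change, a user-specified max_model_len that exceeds the length derived from the model's config.json no longer fails unconditionally: setting VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 downgrades the error to a warning. Below is a minimal usage sketch; the model name and lengths are illustrative and not taken from this commit.

# Hypothetical usage sketch: opt in to a max_model_len above the derived limit.
import os

# Set before vLLM reads its environment variables.
os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"

from vllm import LLM

# facebook/opt-125m derives a 2048-token limit from its config.json, so 8192
# would normally raise ValueError; with the env var set, vLLM only warns.
llm = LLM(model="facebook/opt-125m", max_model_len=8192)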
vllm/config.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
 import torch
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
@@ -1541,15 +1542,21 @@ def _get_and_verify_max_len(
                     "Disabling sliding window is not supported for models "
                     "model_max_length in the config. Please raise an issue "
                     "so we can investigate.")
-            pass
         else:
-            raise ValueError(
+            msg = (
                 f"User-specified max_model_len ({max_model_len}) is greater "
-                "than the derived max_model_len "
-                f"({max_len_key}={derived_max_model_len} or model_max_length="
+                f"than the derived max_model_len ({max_len_key}="
+                f"{derived_max_model_len} or model_max_length="
                 f"{model_max_length} in model's config.json). This may lead "
-                "to incorrect model outputs or CUDA errors. Make sure the "
-                "value is correct and within the model context size.")
+                "to incorrect model outputs or CUDA errors.")
+            if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN:
+                logger.warning(
+                    "%s Make sure the value is correct and within the "
+                    "model context size.", msg)
+            else:
+                raise ValueError(
+                    f"{msg} To allow overriding this maximum, set "
+                    "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")
     return int(max_model_len)
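For readers skimming the hunk, here is a standalone sketch of the check it implements. The function name and module layout are simplified stand-ins, not vLLM's actual code.

# Simplified stand-in for the warn-or-raise logic above; not vLLM's actual module.
import logging
import os

logger = logging.getLogger("max_len_check")


def check_max_model_len(max_model_len: int, derived_max_model_len: int) -> int:
    """Return max_model_len, warning or raising if it exceeds the derived limit."""
    if max_model_len <= derived_max_model_len:
        return int(max_model_len)
    msg = (f"User-specified max_model_len ({max_model_len}) is greater than "
           f"the derived max_model_len ({derived_max_model_len}). This may "
           "lead to incorrect model outputs or CUDA errors.")
    if os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in ("1", "true"):
        # Override allowed: warn but keep the user's value.
        logger.warning("%s Make sure the value is correct and within the "
                       "model context size.", msg)
        return int(max_model_len)
    raise ValueError(f"{msg} To allow overriding this maximum, set "
                     "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1")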
vllm/envs.py (10 changed lines)
@@ -50,6 +50,7 @@ if TYPE_CHECKING:
     VLLM_NO_DEPRECATION_WARNING: bool = False
     CMAKE_BUILD_TYPE: Optional[str] = None
     VERBOSE: bool = False
+    VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
 
 
 def get_default_cache_root():
@@ -331,6 +332,15 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # If set, vllm will skip the deprecation warnings.
     "VLLM_NO_DEPRECATION_WARNING":
     lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))),
+
+    # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows
+    # the user to specify a max sequence length greater than
+    # the max length derived from the model's config.json.
+    # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1.
+    "VLLM_ALLOW_LONG_MAX_MODEL_LEN":
+    lambda:
+    (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in
+     ("1", "true")),
 }
 
 # end-env-vars-definition
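For reference, a short sketch of how the new lambda interprets values; the helper name _parse is illustrative, and the default when the variable is unset is "0", i.e. disabled.

# Mirrors the parsing rule in the lambda above: only "1" and "true"
# (case-insensitive, surrounding whitespace ignored) enable the override.
def _parse(value: str) -> bool:
    return value.strip().lower() in ("1", "true")

assert _parse("1") and _parse("true") and _parse(" TRUE ")
assert not _parse("0") and not _parse("yes") and not _parse("")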