diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 66c7a8dd..804b2fb2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -269,7 +269,7 @@ steps: - csrc/ - vllm/model_executor/layers/quantization - tests/quantization - command: pytest -v -s quantization + command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" diff --git a/tests/utils.py b/tests/utils.py index 020c33b8..115cab80 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -16,6 +16,7 @@ import requests from openai.types.completion import Completion from typing_extensions import ParamSpec, assert_never +import vllm.envs as envs from tests.models.utils import TextTextLogprobs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) @@ -352,10 +353,26 @@ def compare_all_settings(model: str, tokenizer_mode=tokenizer_mode, ) + can_force_load_format = True + + for args in all_args: + if "--load-format" in args: + can_force_load_format = False + break + prompt = "Hello, my name is" token_ids = tokenizer(prompt).input_ids ref_results: List = [] for i, (args, env) in enumerate(zip(all_args, all_envs)): + if can_force_load_format: + # we are comparing the results and + # usually we don't need real weights. + # we force to use dummy weights by default, + # and it should work for most of the cases. + # if not, we can use VLLM_TEST_FORCE_LOAD_FORMAT + # environment variable to force the load format, + # e.g. in quantization tests. + args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT] compare_results: List = [] results = ref_results if i == 0 else compare_results with RemoteOpenAIServer(model, diff --git a/vllm/envs.py b/vllm/envs.py index d15cded4..f65f5c6b 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -397,6 +397,8 @@ environment_variables: Dict[str, Callable[[], Any]] = { lambda: (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in ("1", "true")), + "VLLM_TEST_FORCE_LOAD_FORMAT": + lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"), # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations