[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer (#7836)

Cyrus Leung 2024-08-26 13:31:10 +08:00 committed by GitHub
parent 0b769992ec
commit 029c71de11
2 changed files with 27 additions and 15 deletions

tests/utils.py

@@ -11,13 +11,14 @@ from typing import Any, Callable, Dict, List, Optional
 
 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec
 
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
 
@@ -60,39 +61,50 @@ class RemoteOpenAIServer:
 
     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 vllm_serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
-                raise ValueError("You have manually specified the port"
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")
-            cli_args = cli_args + ["--port", str(get_open_port())]
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
 
         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            engine_config = engine_args.create_engine_config()
+            dummy_loader = DefaultModelLoader(engine_config.load_config)
+            dummy_loader._prepare_weights(engine_config.model_config.model,
+                                          engine_config.model_config.revision,
+                                          fall_back_to_pt=True)
+
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
-                                     env=env,
-                                     stdout=sys.stdout,
-                                     stderr=sys.stderr)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)
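
Why this speeds up CI: an unfiltered snapshot_download(model) mirrors every file in the HF repo, including weight formats the server will never load, whereas routing the pre-download through DefaultModelLoader._prepare_weights reuses the server's own file selection (safetensors first, falling back to .pt/.bin weights when fall_back_to_pt=True). A minimal sketch of the difference in terms of plain huggingface_hub; the allow_patterns values below are illustrative, not vLLM's actual list:

    from huggingface_hub import snapshot_download

    # Old behavior: mirror the entire repo, e.g. both *.bin and
    # *.safetensors copies of the weights, plus unrelated files.
    snapshot_download("facebook/opt-125m")

    # What the loader-based pre-download achieves, conceptually:
    # fetch only the files that will actually be loaded.
    snapshot_download(
        "facebook/opt-125m",
        allow_patterns=["*.safetensors", "*.json", "*.txt"],
    )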

vllm/engine/arg_utils.py

@@ -742,7 +742,7 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args
 
-    def create_engine_config(self, ) -> EngineConfig:
+    def create_engine_config(self) -> EngineConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if self.model.endswith(".gguf"):
             self.quantization = self.load_format = "gguf"
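
With this change a test constructs the helper the same way as before, but the model is fetched before the server subprocess starts, so the health-check timeout covers only startup. A hedged usage sketch (model name and flags are illustrative; get_client() is assumed to be the helper's existing OpenAI-client accessor in the same file):

    with RemoteOpenAIServer("facebook/opt-125m",
                            ["--max-model-len", "2048"]) as server:
        # Weights are already cached by _prepare_weights, so the
        # default 240 s wait is spent on startup, not downloading.
        client = server.get_client()
        print(client.models.list())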