[CI/Build] Avoid downloading all HF files in RemoteOpenAIServer (#7836)
Commit: 029c71de11
Parent: 0b769992ec

@@ -11,13 +11,14 @@ from typing import Any, Callable, Dict, List, Optional

 import openai
 import requests
-from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 from typing_extensions import ParamSpec

 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
+from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip

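The dropped import is the heart of the change: a bare snapshot_download(model) mirrors every file in the Hugging Face repo, including weight formats the server never loads. As a point of comparison only (the repo id below is a placeholder, and this is not what the commit ships), huggingface_hub's allow_patterns can narrow a snapshot to selected files; the commit gets that selectivity through the model loader instead.

    from huggingface_hub import snapshot_download

    # A bare call fetches every file in the repo (all weight formats,
    # tokenizer assets, etc.), which is what the old test helper did.
    snapshot_download("facebook/opt-125m")

    # allow_patterns restricts the snapshot to matching files only; the
    # commit achieves the equivalent via DefaultModelLoader._prepare_weights.
    snapshot_download("facebook/opt-125m",
                      allow_patterns=["*.safetensors", "*.json", "*.txt"])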
@@ -60,39 +61,50 @@ class RemoteOpenAIServer:

     def __init__(self,
                  model: str,
-                 cli_args: List[str],
+                 vllm_serve_args: List[str],
                  *,
                  env_dict: Optional[Dict[str, str]] = None,
                  auto_port: bool = True,
                  max_wait_seconds: Optional[float] = None) -> None:
-        if not model.startswith("/"):
-            # download the model if it's not a local path
-            # to exclude the model download time from the server start time
-            snapshot_download(model)
         if auto_port:
-            if "-p" in cli_args or "--port" in cli_args:
-                raise ValueError("You have manually specified the port"
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
                                  "when `auto_port=True`.")

-            cli_args = cli_args + ["--port", str(get_open_port())]
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]

         parser = FlexibleArgumentParser(
             description="vLLM's remote OpenAI server.")
         parser = make_arg_parser(parser)
-        args = parser.parse_args(cli_args)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)

+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            engine_config = engine_args.create_engine_config()
+            dummy_loader = DefaultModelLoader(engine_config.load_config)
+            dummy_loader._prepare_weights(engine_config.model_config.model,
+                                          engine_config.model_config.revision,
+                                          fall_back_to_pt=True)
+
         env = os.environ.copy()
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
         if env_dict is not None:
             env.update(env_dict)
-        self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
-                                     env=env,
-                                     stdout=sys.stdout,
-                                     stderr=sys.stderr)
+        self.proc = subprocess.Popen(
+            ["vllm", "serve", model, *vllm_serve_args],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
         max_wait_seconds = max_wait_seconds or 240
         self._wait_for_server(url=self.url_for("health"),
                               timeout=max_wait_seconds)
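For reference, a minimal sketch of the new pre-download path, stitched together from the calls in the hunk above. It assumes a vLLM checkout where these internals are importable; the model name is a placeholder and _prepare_weights is a private helper, so treat this as illustration rather than a supported API.

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.model_executor.model_loader.loader import DefaultModelLoader

    # Build an engine config for a placeholder model (assumption: opt-125m
    # stands in for whatever model the test requests).
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    engine_config = engine_args.create_engine_config()

    # _prepare_weights resolves and downloads only the weight/config files
    # the selected load format needs (falling back to .pt), instead of
    # mirroring the entire HF repo as snapshot_download(model) did.
    loader = DefaultModelLoader(engine_config.load_config)
    loader._prepare_weights(engine_config.model_config.model,
                            engine_config.model_config.revision,
                            fall_back_to_pt=True)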
@@ -742,7 +742,7 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args

-    def create_engine_config(self, ) -> EngineConfig:
+    def create_engine_config(self) -> EngineConfig:
         # gguf file needs a specific model loader and doesn't use hf_repo
         if self.model.endswith(".gguf"):
             self.quantization = self.load_format = "gguf"