diff --git a/tests/test_utils.py b/tests/test_utils.py index a6c3896f..0b674ea6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,11 +1,13 @@ import asyncio +import os +import socket import sys from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, Tuple, TypeVar) import pytest -from vllm.utils import deprecate_kwargs, merge_async_iterators +from vllm.utils import deprecate_kwargs, get_open_port, merge_async_iterators from .utils import error_on_warning @@ -116,3 +118,15 @@ def test_deprecate_kwargs_additional_message(): with pytest.warns(DeprecationWarning, match="abcd"): dummy(old_arg=1) + + +def test_get_open_port(): + os.environ["VLLM_PORT"] = "5678" + # make sure we can get multiple ports, even if the env var is set + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: + s1.bind(("localhost", get_open_port())) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: + s2.bind(("localhost", get_open_port())) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: + s3.bind(("localhost", get_open_port())) + os.environ.pop("VLLM_PORT") diff --git a/vllm/envs.py b/vllm/envs.py index bef343d0..7d5c7371 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -99,6 +99,9 @@ environment_variables: Dict[str, Callable[[], Any]] = { lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), # used in distributed environment to manually set the communication port + # Note: if VLLM_PORT is set, and some code asks for multiple ports, the + # VLLM_PORT will be used as the first port, and the rest will be generated + # by incrementing the VLLM_PORT value. # '0' is used to make mypy happy 'VLLM_PORT': lambda: int(os.getenv('VLLM_PORT', '0')) diff --git a/vllm/utils.py b/vllm/utils.py index 2781eceb..2bd24d08 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -289,7 +289,15 @@ def get_distributed_init_method(ip: str, port: int) -> str: def get_open_port() -> int: port = envs.VLLM_PORT if port is not None: - return port + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", + port - 1, port) # try ipv4 try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: