[ci] try to add multi-node tests (#6280)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Author: youkaichao
Date: 2024-07-12 21:51:48 -07:00 (committed by GitHub)
Parent: d80aef3776
Commit: 41708e5034
13 changed files with 229 additions and 274 deletions

View File

@@ -2,16 +2,17 @@
 set -euox pipefail
 
-if [[ $# -lt 3 ]]; then
-    echo "Please provide the number of nodes and GPU per node."
+if [[ $# -lt 4 ]]; then
+    echo "Usage: .buildkite/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN"
     exit 1
 fi
 
-NUM_NODES=$1
-NUM_GPUS=$2
-DOCKER_IMAGE=$3
+WORKING_DIR=$1
+NUM_NODES=$2
+NUM_GPUS=$3
+DOCKER_IMAGE=$4
 
-shift 3
+shift 4
 COMMANDS=("$@")
 if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then
     echo "The number of commands must be equal to the number of nodes."
@@ -40,13 +41,40 @@ start_nodes() {
             fi
         done
         GPU_DEVICES+='"'
-        # echo "Starting node$node with GPU devices: $GPU_DEVICES"
-        docker run -d --gpus "$GPU_DEVICES" --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE tail -f /dev/null
+
+        # start the container in detached mode
+        # things to note:
+        # 1. --shm-size=10.24gb is required. don't use --ipc=host
+        # 2. pass HF_TOKEN to the container
+        # 3. map the huggingface cache directory to the container
+        # 4. assign ip addresses to the containers (head node: 192.168.10.10,
+        #    worker nodes: starting from 192.168.10.11)
+        docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN -v ~/.cache/huggingface:/root/.cache/huggingface --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE /bin/bash -c "tail -f /dev/null"
+
+        # organize containers into a ray cluster
+        if [ $node -eq 0 ]; then
+            # start the ray head node
+            docker exec -d node$node /bin/bash -c "ray start --head --port=6379 --block"
+            # wait for the head node to be ready
+            sleep 10
+        else
+            # start the ray worker nodes, and connect them to the head node
+            docker exec -d node$node /bin/bash -c "ray start --address=192.168.10.10:6379 --block"
+        fi
     done
+
+    # wait for the cluster to be ready
+    sleep 10
+
+    # print the cluster status
+    docker exec node0 /bin/bash -c "ray status"
 }
 
 run_nodes() {
-    for node in $(seq 0 $(($NUM_NODES-1))); do
+    # important: iterate in reverse order to start the head node last
+    # we start the worker nodes first, in detached mode, and then start the head node
+    # in the foreground, so that the output of the head node is visible in the buildkite logs
+    for node in $(seq $(($NUM_NODES - 1)) -1 0); do
         GPU_DEVICES='"device='
         for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
             DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
@@ -57,10 +85,10 @@ run_nodes() {
         done
         GPU_DEVICES+='"'
         echo "Running node$node with GPU devices: $GPU_DEVICES"
-        if [ $node -lt $(($NUM_NODES - 1)) ]; then
-            docker exec -d node$node /bin/bash -c "${COMMANDS[$node]}"
+        if [ $node -ne 0 ]; then
+            docker exec -d node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
         else
-            docker exec node$node /bin/bash -c "${COMMANDS[$node]}"
+            docker exec node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
         fi
     done
 }
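
For orientation, here is a small hypothetical Python sketch (not part of the commit; the gpu_devices helper and the literal NUM_NODES/NUM_GPUS values are illustrative) of the logic the script implements: each node is given a contiguous slice of NUM_GPUS device indices, and run_nodes visits worker nodes first in detached mode before running the head node (node 0) last in the foreground.

# Hypothetical sketch of the script's device assignment and start order.
NUM_NODES = 2
NUM_GPUS = 2

def gpu_devices(node: int) -> str:
    # node i uses device indices [i * NUM_GPUS, (i + 1) * NUM_GPUS)
    devices = [node * NUM_GPUS + g for g in range(NUM_GPUS)]
    return '"device=' + ",".join(map(str, devices)) + '"'

# reverse order: worker nodes start detached, the head node (node 0) runs last
# in the foreground so its output appears in the Buildkite logs
for node in reversed(range(NUM_NODES)):
    mode = "detached" if node != 0 else "foreground"
    print(f"node{node}: {gpu_devices(node)} ({mode})")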

View File

@@ -68,6 +68,17 @@ steps:
   - pytest -v -s distributed/test_comm_ops.py
   - pytest -v -s distributed/test_shm_broadcast.py
 
+- label: 2 Node Tests (4 GPUs in total)
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  num_nodes: 2
+  commands:
+  - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
+  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
+
 - label: Distributed Tests (2 GPUs)
   mirror_hardwares: [amd]
   working_dir: "/vllm-workspace/tests"
@@ -213,7 +224,10 @@ steps:
 
 - label: Tensorizer Test
   #mirror_hardwares: [amd]
-  command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
+  commands:
+  - apt-get install curl libsodium23
+  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  - pytest -v -s tensorizer_loader
 
 - label: Metrics Test
   mirror_hardwares: [amd]

View File

@ -1,35 +1,26 @@
import openai # use the official client for correctness check import openai # use the official client for correctness check
import pytest import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
from ..utils import VLLM_PATH, RemoteOpenAIServer from ..utils import RemoteOpenAIServer
# any model with a chat template should work here # any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m" MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ray_ctx(): def server():
ray.init(runtime_env={"working_dir": VLLM_PATH}) with RemoteOpenAIServer([
yield "--model",
ray.shutdown() MODEL_NAME,
# use half precision for speed and memory savings in CI environment
"--dtype",
@pytest.fixture(scope="module") "float16",
def server(ray_ctx): "--max-model-len",
return RemoteOpenAIServer([ "2048",
"--model", "--enforce-eager",
MODEL_NAME, "--engine-use-ray"
# use half precision for speed and memory savings in CI environment ]) as remote_server:
"--dtype", yield remote_server
"float16",
"--max-model-len",
"2048",
"--enforce-eager",
"--engine-use-ray"
])
@pytest.fixture(scope="module") @pytest.fixture(scope="module")

View File

@@ -2,11 +2,8 @@ import os
 
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 # downloading lora to test lora requests
@@ -21,14 +18,7 @@ pytestmark = pytest.mark.asyncio
 
 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(ray_ctx):
+def server():
     args = [
         "--model",
         MODEL_NAME,
@@ -50,7 +40,8 @@ def server(ray_ctx):
     args += [
         "--enforce-eager",
     ]
 
-    return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE)
+    with RemoteOpenAIServer(args) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")

View File

@@ -10,3 +10,4 @@ test_result = all(
 
 expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
 assert test_result == expected, f"Expected {expected}, got {test_result}"
+print("Same node test passed!")
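
For context, this test checks whether all ranks run on one host and compares the result with VLLM_TEST_SAME_HOST (set to 0 in the new two-node job). The sketch below is a hypothetical stand-in, not vLLM's actual implementation: it gathers each rank's hostname over a gloo process group that torchrun is assumed to have configured.

# Hypothetical same-node check; illustrative only, not vLLM's implementation.
import os
import socket

import torch.distributed as dist

def ranks_on_same_host() -> bool:
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo")  # env:// rendezvous from torchrun
    hostnames = [None] * dist.get_world_size()
    dist.all_gather_object(hostnames, socket.gethostname())
    return len(set(hostnames)) == 1

if __name__ == "__main__":
    same = ranks_on_same_host()
    expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
    assert same == expected, f"Expected {expected}, got {same}"
    print("Same node test passed!")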

View File

@@ -6,15 +6,12 @@ from typing import List
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 import torch
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
 
-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -29,35 +26,29 @@ def zephyr_lora_files():
 
 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files, ray_ctx):
-    return RemoteOpenAIServer([
+def server(zephyr_lora_files):
+    with RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "8192",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
         f"zephyr-lora2={zephyr_lora_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
         "2",
         "--max-num-seqs",
         "128",
-    ])
+    ]) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")

View File

@@ -6,9 +6,6 @@ from typing import List
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
@@ -16,7 +13,7 @@ from openai import BadRequestError
 
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -31,35 +28,29 @@ def zephyr_lora_files():
 
 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files, ray_ctx):
-    return RemoteOpenAIServer([
+def server(zephyr_lora_files):
+    with RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "8192",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
         f"zephyr-lora2={zephyr_lora_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
         "2",
         "--max-num-seqs",
         "128",
-    ])
+    ]) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")

View File

@@ -3,33 +3,26 @@ import base64
 import numpy as np
 import openai
 import pytest
-import ray
 
-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer
 
 EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
 
 
 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def embedding_server(ray_ctx):
-    return RemoteOpenAIServer([
+def embedding_server():
+    with RemoteOpenAIServer([
         "--model",
         EMBEDDING_MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--enforce-eager",
         "--max-model-len",
         "8192",
         "--enforce-eager",
-    ])
+    ]) as remote_server:
+        yield remote_server
 
 
 @pytest.mark.asyncio

View File

@@ -1,12 +1,9 @@
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 
-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,35 +18,29 @@ def zephyr_lora_files():
 
 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files, ray_ctx):
-    return RemoteOpenAIServer([
+def server(zephyr_lora_files):
+    with RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "8192",
         "--enforce-eager",
         # lora config below
         "--enable-lora",
         "--lora-modules",
         f"zephyr-lora={zephyr_lora_files}",
         f"zephyr-lora2={zephyr_lora_files}",
         "--max-lora-rank",
         "64",
         "--max-cpu-loras",
         "2",
         "--max-num-seqs",
         "128",
-    ])
+    ]) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")

View File

@@ -3,7 +3,6 @@ from typing import Dict, List
 import openai
 import pytest
 import pytest_asyncio
-import ray
 
 from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64
@@ -23,25 +22,19 @@ TEST_IMAGE_URLS = [
 
 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(ray_ctx):
-    return RemoteOpenAIServer([
+def server():
+    with RemoteOpenAIServer([
         "--model",
         MODEL_NAME,
         "--dtype",
         "bfloat16",
         "--max-model-len",
         "4096",
         "--enforce-eager",
         "--chat-template",
         str(LLAVA_CHAT_TEMPLATE),
-    ])
+    ]) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")

View File

@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
 
 import openai
 import pytest
-import ray
 import torch
 from tensorizer import EncryptionParams
@@ -22,7 +21,7 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          tensorize_vllm_model)
 
 from ..conftest import VllmRunner, cleanup
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer
 
 # yapf conflicts with isort for this docstring
@@ -220,23 +219,21 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
         json.dumps(model_loader_extra_config),
     ]
 
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-
-    server = RemoteOpenAIServer(openai_args)
-    print("Server ready.")
-
-    client = server.get_client()
-    completion = client.completions.create(model=model_ref,
-                                           prompt="Hello, my name is",
-                                           max_tokens=5,
-                                           temperature=0.0)
-
-    assert completion.id is not None
-    assert len(completion.choices) == 1
-    assert len(completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+    with RemoteOpenAIServer(openai_args) as server:
+        print("Server ready.")
+
+        client = server.get_client()
+        completion = client.completions.create(model=model_ref,
+                                               prompt="Hello, my name is",
+                                               max_tokens=5,
+                                               temperature=0.0)
+
+        assert completion.id is not None
+        assert len(completion.choices) == 1
+        assert len(completion.choices[0].text) >= 5
+        assert completion.choices[0].finish_reason == "length"
+        assert completion.usage == openai.types.CompletionUsage(
+            completion_tokens=5, prompt_tokens=6, total_tokens=11)
 
 
 def test_raise_value_error_on_invalid_load_format(vllm_runner):
@@ -282,7 +279,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
     base_model.model.llm_engine.model_executor.shutdown()
     del base_model
     cleanup()
-    ray.shutdown()
 
     # load model with two shards and serialize with encryption
     model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
@@ -305,7 +301,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
     assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
     assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
     cleanup()
-    ray.shutdown()
 
     loaded_vllm_model = vllm_runner(
         model_ref,

View File

@@ -49,53 +49,7 @@ class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
     MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds
 
-    class _RemoteRunner:
-
-        def __init__(self, cli_args: List[str], *, wait_url: str,
-                     wait_timeout: float) -> None:
-            env = os.environ.copy()
-            env["PYTHONUNBUFFERED"] = "1"
-            self.proc = subprocess.Popen(
-                [
-                    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
-                    *cli_args
-                ],
-                env=env,
-                stdout=sys.stdout,
-                stderr=sys.stderr,
-            )
-
-            self._wait_for_server(url=wait_url, timeout=wait_timeout)
-
-        def ready(self):
-            return True
-
-        def _wait_for_server(self, *, url: str, timeout: float):
-            # run health check
-            start = time.time()
-            while True:
-                try:
-                    if requests.get(url).status_code == 200:
-                        break
-                except Exception as err:
-                    if self.proc.poll() is not None:
-                        raise RuntimeError(
-                            "Server exited unexpectedly.") from err
-
-                    time.sleep(0.5)
-                    if time.time() - start > timeout:
-                        raise RuntimeError(
-                            "Server failed to start in time.") from err
-
-        def __del__(self):
-            if hasattr(self, "proc"):
-                self.proc.terminate()
-
-    def __init__(self,
-                 cli_args: List[str],
-                 *,
-                 auto_port: bool = True,
-                 num_gpus: int = 1) -> None:
+    def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
         if auto_port:
             if "-p" in cli_args or "--port" in cli_args:
                 raise ValueError("You have manually specified the port"
@@ -108,13 +62,41 @@ class RemoteOpenAIServer:
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)
 
-        self._runner = ray.remote(num_gpus=num_gpus)(
-            self._RemoteRunner).remote(
-                cli_args,
-                wait_url=self.url_for("health"),
-                wait_timeout=self.MAX_SERVER_START_WAIT_S)
-
-        self._wait_until_ready()
+        env = os.environ.copy()
+        # the current process might initialize cuda,
+        # to be safe, we should use spawn method
+        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        self.proc = subprocess.Popen(
+            [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
+            cli_args,
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr)
+        self._wait_for_server(url=self.url_for("health"),
+                              timeout=self.MAX_SERVER_START_WAIT_S)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.terminate()
+
+    def _wait_for_server(self, *, url: str, timeout: float):
+        # run health check
+        start = time.time()
+        while True:
+            try:
+                if requests.get(url).status_code == 200:
+                    break
+            except Exception as err:
+                result = self.proc.poll()
+                if result is not None and result != 0:
+                    raise RuntimeError("Server exited unexpectedly.") from err
+
+                time.sleep(0.5)
+                if time.time() - start > timeout:
+                    raise RuntimeError(
+                        "Server failed to start in time.") from err
 
     @property
     def url_root(self) -> str:
@@ -123,9 +105,6 @@ class RemoteOpenAIServer:
     def url_for(self, *parts: str) -> str:
         return self.url_root + "/" + "/".join(parts)
 
-    def _wait_until_ready(self) -> None:
-        ray.get(self._runner.ready.remote())
-
     def get_client(self):
         return openai.OpenAI(
             base_url=self.url_for("v1"),
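
With this change RemoteOpenAIServer is a plain context manager around a subprocess: __exit__ terminates the server, so tests no longer depend on Ray actors or __del__ for cleanup. A minimal usage sketch, assuming it runs inside the vLLM test suite (the tests.utils import path and the CLI flags shown are illustrative, not from this diff):

# Minimal usage sketch of the reworked RemoteOpenAIServer; illustrative only.
import pytest

from tests.utils import RemoteOpenAIServer  # assumed import path


@pytest.fixture(scope="module")
def server():
    # the server subprocess is terminated when the with-block exits
    with RemoteOpenAIServer(["--model", "facebook/opt-125m",
                             "--enforce-eager"]) as remote_server:
        yield remote_server


def test_models_endpoint(server):
    client = server.get_client()  # sync OpenAI client pointed at /v1
    assert client.models.list().data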

View File

@@ -224,16 +224,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # broadcasted to.
         self.non_driver_workers: List[RayWorkerWrapper] = []
 
-        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
-            for tp_rank in range(self.parallel_config.tensor_parallel_size):
-                rank = (pp_rank *
-                        self.parallel_config.tensor_parallel_size) + tp_rank
-                if rank == 0:
-                    pass
-                elif rank % self.parallel_config.tensor_parallel_size == 0:
-                    self.tp_driver_workers.append(self.workers[rank - 1])
-                else:
-                    self.non_driver_workers.append(self.workers[rank - 1])
+        for idx, rank in enumerate(worker_ranks[1:]):
+            # We need to skip the driver worker, which we
+            # do by skipping worker_ranks[0] which is always 0.
+            if rank % self.parallel_config.tensor_parallel_size == 0:
+                self.tp_driver_workers.append(self.workers[idx])
+            else:
+                self.non_driver_workers.append(self.workers[idx])
 
     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
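
To make the new partitioning concrete, here is a standalone sketch with tensor_parallel_size=2 and pipeline_parallel_size=2 (the shape of the new 2-node CI job); it assumes worker_ranks is dense with the driver at rank 0, which is what the surrounding code relies on. The worker names are placeholders for the Ray worker wrappers.

# Standalone illustration of the worker classification above; not executor code.
tensor_parallel_size = 2
pipeline_parallel_size = 2
world_size = tensor_parallel_size * pipeline_parallel_size

# worker_ranks[0] is the driver (rank 0); the rest map 1:1 onto self.workers
worker_ranks = list(range(world_size))
workers = [f"worker{i}" for i in range(world_size - 1)]  # non-driver Ray workers

tp_driver_workers, non_driver_workers = [], []
for idx, rank in enumerate(worker_ranks[1:]):
    if rank % tensor_parallel_size == 0:
        tp_driver_workers.append(workers[idx])   # first rank of each TP group
    else:
        non_driver_workers.append(workers[idx])

print(tp_driver_workers)   # ['worker1']  -> rank 2, driver of the second TP group
print(non_driver_workers)  # ['worker0', 'worker2'] -> ranks 1 and 3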