[ci] try to add multi-node tests (#6280)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
parent d80aef3776
commit 41708e5034
@@ -2,16 +2,17 @@

 set -euox pipefail

-if [[ $# -lt 3 ]]; then
-    echo "Please provide the number of nodes and GPU per node."
+if [[ $# -lt 4 ]]; then
+    echo "Usage: .buildkite/run-multi-node-test.sh WORKING_DIR NUM_NODES NUM_GPUS DOCKER_IMAGE COMMAND1 COMMAND2 ... COMMANDN"
     exit 1
 fi

-NUM_NODES=$1
-NUM_GPUS=$2
-DOCKER_IMAGE=$3
+WORKING_DIR=$1
+NUM_NODES=$2
+NUM_GPUS=$3
+DOCKER_IMAGE=$4

-shift 3
+shift 4
 COMMANDS=("$@")
 if [ ${#COMMANDS[@]} -ne $NUM_NODES ]; then
     echo "The number of commands must be equal to the number of nodes."
@@ -40,13 +41,40 @@ start_nodes() {
            fi
        done
        GPU_DEVICES+='"'
-        # echo "Starting node$node with GPU devices: $GPU_DEVICES"
-        docker run -d --gpus "$GPU_DEVICES" --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE tail -f /dev/null
+        # start the container in detached mode
+        # things to note:
+        # 1. --shm-size=10.24gb is required. don't use --ipc=host
+        # 2. pass HF_TOKEN to the container
+        # 3. map the huggingface cache directory to the container
+        # 3. assign ip addresses to the containers (head node: 192.168.10.10, worker nodes:
+        # starting from 192.168.10.11)
+        docker run -d --gpus "$GPU_DEVICES" --shm-size=10.24gb -e HF_TOKEN -v ~/.cache/huggingface:/root/.cache/huggingface --name node$node --network docker-net --ip 192.168.10.$((10 + $node)) --rm $DOCKER_IMAGE /bin/bash -c "tail -f /dev/null"
+
+        # organize containers into a ray cluster
+        if [ $node -eq 0 ]; then
+            # start the ray head node
+            docker exec -d node$node /bin/bash -c "ray start --head --port=6379 --block"
+            # wait for the head node to be ready
+            sleep 10
+        else
+            # start the ray worker nodes, and connect them to the head node
+            docker exec -d node$node /bin/bash -c "ray start --address=192.168.10.10:6379 --block"
+        fi
    done
+
+    # wait for the cluster to be ready
+    sleep 10
+
+    # print the cluster status
+    docker exec node0 /bin/bash -c "ray status"
 }

 run_nodes() {
-    for node in $(seq 0 $(($NUM_NODES-1))); do
+    # important: iterate in reverse order to start the head node last
+    # we start the worker nodes first, in detached mode, and then start the head node
+    # in the foreground, so that the output of the head node is visible in the buildkite logs
+    for node in $(seq $(($NUM_NODES - 1)) -1 0); do
        GPU_DEVICES='"device='
        for node_gpu in $(seq 0 $(($NUM_GPUS - 1))); do
            DEVICE_NUM=$(($node * $NUM_GPUS + $node_gpu))
@@ -57,10 +85,10 @@ run_nodes() {
        done
        GPU_DEVICES+='"'
        echo "Running node$node with GPU devices: $GPU_DEVICES"
-        if [ $node -lt $(($NUM_NODES - 1)) ]; then
-            docker exec -d node$node /bin/bash -c "${COMMANDS[$node]}"
+        if [ $node -ne 0 ]; then
+            docker exec -d node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
        else
-            docker exec node$node /bin/bash -c "${COMMANDS[$node]}"
+            docker exec node$node /bin/bash -c "cd $WORKING_DIR ; ${COMMANDS[$node]}"
        fi
    done
 }
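The script derives each container's IP address and GPU set from its node index. A small Python sketch of that arithmetic for reference (the helper names are illustrative and do not exist in the script):

    # Sketch of the per-node values computed by start_nodes()/run_nodes().
    # The helper names are illustrative only.

    def node_ip(node: int) -> str:
        # head node: 192.168.10.10, worker nodes: 192.168.10.11, 192.168.10.12, ...
        return f"192.168.10.{10 + node}"

    def gpu_device_string(node: int, num_gpus: int) -> str:
        # e.g. node 1 with NUM_GPUS=2 -> '"device=2,3"', matching GPU_DEVICES above
        devices = [str(node * num_gpus + gpu) for gpu in range(num_gpus)]
        return '"device=' + ",".join(devices) + '"'

    for node in range(2):  # NUM_NODES=2, NUM_GPUS=2 as in the new pipeline step
        print(node, node_ip(node), gpu_device_string(node, 2))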
@@ -68,6 +68,17 @@ steps:
   - pytest -v -s distributed/test_comm_ops.py
   - pytest -v -s distributed/test_shm_broadcast.py

+- label: 2 Node Tests (4 GPUs in total)
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 2
+  num_nodes: 2
+  commands:
+  - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
+  - TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
+  - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
+  - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
+
 - label: Distributed Tests (2 GPUs)
   mirror_hardwares: [amd]
   working_dir: "/vllm-workspace/tests"
@@ -213,7 +224,10 @@ steps:

 - label: Tensorizer Test
   #mirror_hardwares: [amd]
-  command: apt-get install curl libsodium23 && pytest -v -s tensorizer_loader
+  commands:
+  - apt-get install curl libsodium23
+  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  - pytest -v -s tensorizer_loader

 - label: Metrics Test
   mirror_hardwares: [amd]
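The new "2 Node Tests" step is meant to be driven by the runner script above, which takes the working directory, node and GPU counts, the docker image, and one command per node. A rough, hypothetical sketch of such an invocation (the image tag is a placeholder and the actual Buildkite agent plumbing is not part of this diff):

    import subprocess

    # Hypothetical driver for the "2 Node Tests" step; the image tag is a placeholder.
    working_dir = "/vllm-workspace/tests"
    num_nodes, num_gpus = 2, 2
    docker_image = "<vllm-ci-test-image>"  # placeholder, filled in by CI
    node_commands = [
        # one command per node: node 0 (head, 192.168.10.10), node 1 (192.168.10.11)
        "pytest -v -s distributed/test_pipeline_parallel.py",
        "VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 "
        "--rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py",
    ]

    subprocess.run(
        [".buildkite/run-multi-node-test.sh", working_dir,
         str(num_nodes), str(num_gpus), docker_image, *node_commands],
        check=True)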
@@ -1,35 +1,26 @@
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray

-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "facebook/opt-125m"


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(ray_ctx):
-    return RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "float16",
-        "--max-model-len",
-        "2048",
-        "--enforce-eager",
-        "--engine-use-ray"
-    ])
+def server():
+    with RemoteOpenAIServer([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "float16",
+        "--max-model-len",
+        "2048",
+        "--enforce-eager",
+        "--engine-use-ray"
+    ]) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
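The module-scoped `server` fixture now yields a running server and shuts it down when the module finishes. A minimal sketch of a test consuming it, assuming the fixture and `MODEL_NAME` from the module above (the test body and prompt are illustrative, not taken from this diff):

    # Sketch only: assumes the `server` fixture and MODEL_NAME defined above.
    def test_single_completion_sketch(server):
        client = server.get_client()  # synchronous OpenAI client bound to the local server
        completion = client.completions.create(model=MODEL_NAME,
                                               prompt="Hello, my name is",
                                               max_tokens=5,
                                               temperature=0.0)
        assert len(completion.choices) == 1
        assert completion.choices[0].text  # some text was generated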
@@ -2,11 +2,8 @@ import os

 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray

-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer

 # downloading lora to test lora requests

@@ -21,14 +18,7 @@ pytestmark = pytest.mark.asyncio


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(ray_ctx):
+def server():
     args = [
         "--model",
         MODEL_NAME,
@@ -50,7 +40,8 @@ def server(ray_ctx):
     args += [
         "--enforce-eager",
     ]
-    return RemoteOpenAIServer(args, num_gpus=PP_SIZE * TP_SIZE)
+    with RemoteOpenAIServer(args) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
@@ -10,3 +10,4 @@ test_result = all(

 expected = os.environ.get("VLLM_TEST_SAME_HOST", "1") == "1"
 assert test_result == expected, f"Expected {expected}, got {test_result}"
+print("Same node test passed!")
@@ -6,15 +6,12 @@ from typing import List
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 import torch
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError

-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -29,35 +26,29 @@ def zephyr_lora_files():


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files, ray_ctx):
-    return RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "bfloat16",
-        "--max-model-len",
-        "8192",
-        "--enforce-eager",
-        # lora config below
-        "--enable-lora",
-        "--lora-modules",
-        f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
-        "--max-lora-rank",
-        "64",
-        "--max-cpu-loras",
-        "2",
-        "--max-num-seqs",
-        "128",
-    ])
+def server(zephyr_lora_files):
+    with RemoteOpenAIServer([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128",
+    ]) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
@@ -6,9 +6,6 @@ from typing import List
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 import requests
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
@@ -16,7 +13,7 @@ from openai import BadRequestError

 from vllm.transformers_utils.tokenizer import get_tokenizer

-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -31,35 +28,29 @@ def zephyr_lora_files():


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files, ray_ctx):
-    return RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "bfloat16",
-        "--max-model-len",
-        "8192",
-        "--enforce-eager",
-        # lora config below
-        "--enable-lora",
-        "--lora-modules",
-        f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
-        "--max-lora-rank",
-        "64",
-        "--max-cpu-loras",
-        "2",
-        "--max-num-seqs",
-        "128",
-    ])
+def server(zephyr_lora_files):
+    with RemoteOpenAIServer([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128",
+    ]) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
@@ -3,33 +3,26 @@ import base64
 import numpy as np
 import openai
 import pytest
-import ray

-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer

 EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def embedding_server(ray_ctx):
-    return RemoteOpenAIServer([
-        "--model",
-        EMBEDDING_MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "bfloat16",
-        "--enforce-eager",
-        "--max-model-len",
-        "8192",
-        "--enforce-eager",
-    ])
+def embedding_server():
+    with RemoteOpenAIServer([
+        "--model",
+        EMBEDDING_MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+    ]) as remote_server:
+        yield remote_server


 @pytest.mark.asyncio
@@ -1,12 +1,9 @@
 import openai  # use the official client for correctness check
 import pytest
-# using Ray for overall ease of process management, parallel requests,
-# and debugging.
-import ray
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download

-from ...utils import VLLM_PATH, RemoteOpenAIServer
+from ...utils import RemoteOpenAIServer

 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
@@ -21,35 +18,29 @@ def zephyr_lora_files():


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(zephyr_lora_files, ray_ctx):
-    return RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
-        # use half precision for speed and memory savings in CI environment
-        "--dtype",
-        "bfloat16",
-        "--max-model-len",
-        "8192",
-        "--enforce-eager",
-        # lora config below
-        "--enable-lora",
-        "--lora-modules",
-        f"zephyr-lora={zephyr_lora_files}",
-        f"zephyr-lora2={zephyr_lora_files}",
-        "--max-lora-rank",
-        "64",
-        "--max-cpu-loras",
-        "2",
-        "--max-num-seqs",
-        "128",
-    ])
+def server(zephyr_lora_files):
+    with RemoteOpenAIServer([
+        "--model",
+        MODEL_NAME,
+        # use half precision for speed and memory savings in CI environment
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "8192",
+        "--enforce-eager",
+        # lora config below
+        "--enable-lora",
+        "--lora-modules",
+        f"zephyr-lora={zephyr_lora_files}",
+        f"zephyr-lora2={zephyr_lora_files}",
+        "--max-lora-rank",
+        "64",
+        "--max-cpu-loras",
+        "2",
+        "--max-num-seqs",
+        "128",
+    ]) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
@@ -3,7 +3,6 @@ from typing import Dict, List
 import openai
 import pytest
 import pytest_asyncio
-import ray

 from vllm.multimodal.utils import ImageFetchAiohttp, encode_image_base64

@@ -23,25 +22,19 @@ TEST_IMAGE_URLS = [


 @pytest.fixture(scope="module")
-def ray_ctx():
-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-    yield
-    ray.shutdown()
-
-
-@pytest.fixture(scope="module")
-def server(ray_ctx):
-    return RemoteOpenAIServer([
-        "--model",
-        MODEL_NAME,
-        "--dtype",
-        "bfloat16",
-        "--max-model-len",
-        "4096",
-        "--enforce-eager",
-        "--chat-template",
-        str(LLAVA_CHAT_TEMPLATE),
-    ])
+def server():
+    with RemoteOpenAIServer([
+        "--model",
+        MODEL_NAME,
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "4096",
+        "--enforce-eager",
+        "--chat-template",
+        str(LLAVA_CHAT_TEMPLATE),
+    ]) as remote_server:
+        yield remote_server


 @pytest.fixture(scope="module")
@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch

 import openai
 import pytest
-import ray
 import torch
 from tensorizer import EncryptionParams

@@ -22,7 +21,7 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
                                                          tensorize_vllm_model)

 from ..conftest import VllmRunner, cleanup
-from ..utils import VLLM_PATH, RemoteOpenAIServer
+from ..utils import RemoteOpenAIServer

 # yapf conflicts with isort for this docstring

@@ -220,23 +219,21 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
         json.dumps(model_loader_extra_config),
     ]

-    ray.init(runtime_env={"working_dir": VLLM_PATH})
-
-    server = RemoteOpenAIServer(openai_args)
-    print("Server ready.")
-
-    client = server.get_client()
-    completion = client.completions.create(model=model_ref,
-                                           prompt="Hello, my name is",
-                                           max_tokens=5,
-                                           temperature=0.0)
-
-    assert completion.id is not None
-    assert len(completion.choices) == 1
-    assert len(completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
-    assert completion.usage == openai.types.CompletionUsage(
-        completion_tokens=5, prompt_tokens=6, total_tokens=11)
+    with RemoteOpenAIServer(openai_args) as server:
+        print("Server ready.")
+
+        client = server.get_client()
+        completion = client.completions.create(model=model_ref,
+                                               prompt="Hello, my name is",
+                                               max_tokens=5,
+                                               temperature=0.0)
+
+        assert completion.id is not None
+        assert len(completion.choices) == 1
+        assert len(completion.choices[0].text) >= 5
+        assert completion.choices[0].finish_reason == "length"
+        assert completion.usage == openai.types.CompletionUsage(
+            completion_tokens=5, prompt_tokens=6, total_tokens=11)


 def test_raise_value_error_on_invalid_load_format(vllm_runner):
@@ -282,7 +279,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
     base_model.model.llm_engine.model_executor.shutdown()
     del base_model
     cleanup()
-    ray.shutdown()

     # load model with two shards and serialize with encryption
     model_path = str(tmp_path / (model_ref + "-%02d.tensors"))
@@ -305,7 +301,6 @@ def test_deserialized_encrypted_vllm_model_with_tp_has_same_outputs(vllm_runner,
     assert os.path.isfile(model_path % 0), "Serialization subprocess failed"
     assert os.path.isfile(model_path % 1), "Serialization subprocess failed"
     cleanup()
-    ray.shutdown()

     loaded_vllm_model = vllm_runner(
         model_ref,
@@ -49,53 +49,7 @@ class RemoteOpenAIServer:
     DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key
     MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds

-    class _RemoteRunner:
-
-        def __init__(self, cli_args: List[str], *, wait_url: str,
-                     wait_timeout: float) -> None:
-            env = os.environ.copy()
-            env["PYTHONUNBUFFERED"] = "1"
-            self.proc = subprocess.Popen(
-                [
-                    sys.executable, "-m", "vllm.entrypoints.openai.api_server",
-                    *cli_args
-                ],
-                env=env,
-                stdout=sys.stdout,
-                stderr=sys.stderr,
-            )
-
-            self._wait_for_server(url=wait_url, timeout=wait_timeout)
-
-        def ready(self):
-            return True
-
-        def _wait_for_server(self, *, url: str, timeout: float):
-            # run health check
-            start = time.time()
-            while True:
-                try:
-                    if requests.get(url).status_code == 200:
-                        break
-                except Exception as err:
-                    if self.proc.poll() is not None:
-                        raise RuntimeError(
-                            "Server exited unexpectedly.") from err
-
-                    time.sleep(0.5)
-                    if time.time() - start > timeout:
-                        raise RuntimeError(
-                            "Server failed to start in time.") from err
-
-        def __del__(self):
-            if hasattr(self, "proc"):
-                self.proc.terminate()
-
-    def __init__(self,
-                 cli_args: List[str],
-                 *,
-                 auto_port: bool = True,
-                 num_gpus: int = 1) -> None:
+    def __init__(self, cli_args: List[str], *, auto_port: bool = True) -> None:
         if auto_port:
             if "-p" in cli_args or "--port" in cli_args:
                 raise ValueError("You have manually specified the port"
@@ -108,13 +62,41 @@ class RemoteOpenAIServer:
         self.host = str(args.host or 'localhost')
         self.port = int(args.port)

-        self._runner = ray.remote(num_gpus=num_gpus)(
-            self._RemoteRunner).remote(
-                cli_args,
-                wait_url=self.url_for("health"),
-                wait_timeout=self.MAX_SERVER_START_WAIT_S)
-
-        self._wait_until_ready()
+        env = os.environ.copy()
+        # the current process might initialize cuda,
+        # to be safe, we should use spawn method
+        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        self.proc = subprocess.Popen(
+            [sys.executable, "-m", "vllm.entrypoints.openai.api_server"] +
+            cli_args,
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr)
+        self._wait_for_server(url=self.url_for("health"),
+                              timeout=self.MAX_SERVER_START_WAIT_S)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.terminate()
+
+    def _wait_for_server(self, *, url: str, timeout: float):
+        # run health check
+        start = time.time()
+        while True:
+            try:
+                if requests.get(url).status_code == 200:
+                    break
+            except Exception as err:
+                result = self.proc.poll()
+                if result is not None and result != 0:
+                    raise RuntimeError("Server exited unexpectedly.") from err
+
+                time.sleep(0.5)
+                if time.time() - start > timeout:
+                    raise RuntimeError(
+                        "Server failed to start in time.") from err

     @property
     def url_root(self) -> str:
@@ -123,9 +105,6 @@ class RemoteOpenAIServer:
     def url_for(self, *parts: str) -> str:
         return self.url_root + "/" + "/".join(parts)

-    def _wait_until_ready(self) -> None:
-        ray.get(self._runner.ready.remote())
-
     def get_client(self):
         return openai.OpenAI(
             base_url=self.url_for("v1"),
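With the Ray actor gone, RemoteOpenAIServer is a plain subprocess wrapper used as a context manager. A usage sketch, assuming the class above is importable as tests.utils and picking an arbitrary small model:

    import requests

    # Sketch only: assumes tests/utils.py (shown above) is importable from the repo
    # root and that the chosen model is available; the model choice is arbitrary.
    from tests.utils import RemoteOpenAIServer

    with RemoteOpenAIServer(["--model", "facebook/opt-125m", "--enforce-eager"]) as server:
        # __init__ already waited for /health to return 200; re-check for illustration.
        assert requests.get(server.url_for("health")).status_code == 200

        client = server.get_client()
        print([m.id for m in client.models.list().data])
    # __exit__ terminates the api_server subprocess on the way out.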
@@ -224,16 +224,13 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # broadcasted to.
         self.non_driver_workers: List[RayWorkerWrapper] = []

-        for pp_rank in range(self.parallel_config.pipeline_parallel_size):
-            for tp_rank in range(self.parallel_config.tensor_parallel_size):
-                rank = (pp_rank *
-                        self.parallel_config.tensor_parallel_size) + tp_rank
-                if rank == 0:
-                    pass
-                elif rank % self.parallel_config.tensor_parallel_size == 0:
-                    self.tp_driver_workers.append(self.workers[rank - 1])
-                else:
-                    self.non_driver_workers.append(self.workers[rank - 1])
+        for idx, rank in enumerate(worker_ranks[1:]):
+            # We need to skip the driver worker, which we
+            # do by skipping worker_ranks[0] which is always 0.
+            if rank % self.parallel_config.tensor_parallel_size == 0:
+                self.tp_driver_workers.append(self.workers[idx])
+            else:
+                self.non_driver_workers.append(self.workers[idx])

     def _driver_execute_model(
         self, execute_model_req: Optional[ExecuteModelRequest]
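The rewritten loop skips the rank-0 driver and splits the remaining workers by whether their global rank is a multiple of the tensor-parallel size. A standalone toy illustration of that bookkeeping (toy data, not vLLM code):

    # Toy illustration of the rank classification in the new loop; not vLLM code.
    def classify_workers(worker_ranks, workers, tensor_parallel_size):
        tp_driver_workers, non_driver_workers = [], []
        # worker_ranks[0] is always 0 (the driver itself) and is skipped;
        # `workers` here is assumed to exclude the rank-0 driver, as in the diff.
        for idx, rank in enumerate(worker_ranks[1:]):
            if rank % tensor_parallel_size == 0:
                tp_driver_workers.append(workers[idx])
            else:
                non_driver_workers.append(workers[idx])
        return tp_driver_workers, non_driver_workers

    # PP=2, TP=2 -> ranks 0..3; rank 2 leads the second pipeline stage.
    tp_drivers, others = classify_workers([0, 1, 2, 3], ["w1", "w2", "w3"], 2)
    assert tp_drivers == ["w2"] and others == ["w1", "w3"]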