diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 652f04ad..1b89b892 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -55,16 +55,22 @@ RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" # Install torch == 2.4.0 on ROCm RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ *"rocm-5.7"*) \ - pip uninstall -y torch \ - && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \ + pip uninstall -y torch torchaudio torchvision \ + && pip install --no-cache-dir --pre \ + torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ + torchvision==0.19.0.dev20240612 \ --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \ *"rocm-6.0"*) \ - pip uninstall -y torch \ - && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \ + pip uninstall -y torch torchaudio torchvision \ + && pip install --no-cache-dir --pre \ + torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ + torchvision==0.19.0.dev20240612 \ --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \ *"rocm-6.1"*) \ - pip uninstall -y torch \ - && pip install --no-cache-dir --pre torch==2.4.0.dev20240612 \ + pip uninstall -y torch torchaudio torchvision \ + && pip install --no-cache-dir --pre \ + torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \ + torchvision==0.19.0.dev20240612 \ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ *) ;; esac diff --git a/tests/entrypoints/test_openai_chat.py b/tests/entrypoints/test_openai_chat.py index 1c46a511..52e64717 100644 --- a/tests/entrypoints/test_openai_chat.py +++ b/tests/entrypoints/test_openai_chat.py @@ -14,7 +14,7 @@ import torch from huggingface_hub import snapshot_download from openai import BadRequestError -from ..utils import VLLM_PATH, RemoteOpenAIServer +from ..utils import RemoteOpenAIServer # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" @@ -79,7 +79,7 @@ def zephyr_lora_files(): @pytest.fixture(scope="module") def ray_ctx(): - ray.init(runtime_env={"working_dir": VLLM_PATH}) + ray.init() yield ray.shutdown()