Refactor TPU requirements file and pin build dependencies (#10010)

Signed-off-by: Richard Liu <ricliu@google.com>
This commit is contained in:
Richard Liu 2024-11-05 08:48:44 -08:00 committed by GitHub
parent 5952d81139
commit cd34029e91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 26 additions and 64 deletions

View File

@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \
git \ git \
ffmpeg libsm6 libxext6 libgl1 ffmpeg libsm6 libxext6 libgl1
# Install the TPU and Pallas dependencies.
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
# Build vLLM. # Build vLLM.
COPY . . COPY . .
ARG GIT_REPO_CHECK=0 ARG GIT_REPO_CHECK=0
@ -25,7 +19,6 @@ ENV VLLM_TARGET_DEVICE="tpu"
RUN --mount=type=cache,target=/root/.cache/pip \ RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \ --mount=type=bind,source=.git,target=.git \
python3 -m pip install \ python3 -m pip install \
'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
-r requirements-tpu.txt -r requirements-tpu.txt
RUN python3 setup.py develop RUN python3 setup.py develop

View File

@ -119,28 +119,20 @@ Uninstall the existing `torch` and `torch_xla` packages:
pip uninstall torch torch-xla -y pip uninstall torch torch-xla -y
Install `torch` and `torch_xla` Install build dependencies:
.. code-block:: bash
pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
Install JAX and Pallas:
.. code-block:: bash
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Install other build dependencies:
.. code-block:: bash .. code-block:: bash
pip install -r requirements-tpu.txt pip install -r requirements-tpu.txt
VLLM_TARGET_DEVICE="tpu" python setup.py develop
sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
Run the setup script:
.. code-block:: bash
VLLM_TARGET_DEVICE="tpu" python setup.py develop
Provision Cloud TPUs with GKE Provision Cloud TPUs with GKE
----------------------------- -----------------------------
@ -168,45 +160,6 @@ Run the Docker image with the following command:
$ # Make sure to add `--privileged --net host --shm-size=16G`. $ # Make sure to add `--privileged --net host --shm-size=16G`.
$ docker run --privileged --net host --shm-size=16G -it vllm-tpu $ docker run --privileged --net host --shm-size=16G -it vllm-tpu
.. _build_from_source_tpu:
Build from source
-----------------
You can also build and install the TPU backend from source.
First, install the dependencies:
.. code-block:: console
$ # (Recommended) Create a new conda environment.
$ conda create -n myenv python=3.10 -y
$ conda activate myenv
$ # Clean up the existing torch and torch-xla packages.
$ pip uninstall torch torch-xla -y
$ # Install PyTorch and PyTorch XLA.
$ export DATE="20241017"
$ export TORCH_VERSION="2.6.0"
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
$ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
$ # Install JAX and Pallas.
$ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
$ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
$ # Install other build dependencies.
$ pip install -r requirements-tpu.txt
Next, build vLLM from source. This will only take a few seconds:
.. code-block:: console
$ VLLM_TARGET_DEVICE="tpu" python setup.py develop
.. note:: .. note::
Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape. Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.

View File

@ -2,6 +2,22 @@
-r requirements-common.txt -r requirements-common.txt
# Dependencies for TPU # Dependencies for TPU
# Currently, the TPU backend uses a nightly version of PyTorch XLA. cmake>=3.26
# You can install the dependencies in Dockerfile.tpu. ninja
packaging
setuptools-scm>=8
wheel
jinja2
ray[default] ray[default]
# Install torch_xla
--pre
--extra-index-url https://download.pytorch.org/whl/nightly/cpu
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.6.0.dev20241028+cpu
torchvision==0.20.0.dev20241028+cpu
torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
jaxlib==0.4.32.dev20240829
jax==0.4.32.dev20240829