Refactor TPU requirements file and pin build dependencies (#10010)

Signed-off-by: Richard Liu <ricliu@google.com>

parent 5952d81139
commit cd34029e91
@@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \
     git \
     ffmpeg libsm6 libxext6 libgl1
 
-# Install the TPU and Pallas dependencies.
-RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
-RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
 # Build vLLM.
 COPY . .
 ARG GIT_REPO_CHECK=0
@@ -25,7 +19,6 @@ ENV VLLM_TARGET_DEVICE="tpu"
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=.git,target=.git \
     python3 -m pip install \
-        'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
         -r requirements-tpu.txt
 
 RUN python3 setup.py develop
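Both Dockerfile hunks above keep the `--mount=type=cache,target=/root/.cache/pip` BuildKit cache mount. The idea behind that flag is that pip's download cache persists across image rebuilds, so previously fetched wheels are reused instead of re-downloaded. A minimal sketch of that caching pattern, keyed on the download URL (illustrative only — pip's real cache layout is an internal detail, and the names below are my own):

```python
import hashlib
import os
import tempfile

class WheelCache:
    """Toy content cache keyed on URL, mimicking why a persistent
    pip cache directory makes rebuilds cheap."""

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        self.downloads = 0  # counts real "network" fetches

    def _path_for(self, url):
        digest = hashlib.sha256(url.encode()).hexdigest()
        return os.path.join(self.cache_dir, digest)

    def fetch(self, url, download):
        path = self._path_for(url)
        if os.path.exists(path):        # cache hit: no network traffic
            with open(path, "rb") as f:
                return f.read()
        data = download(url)            # cache miss: fetch once
        self.downloads += 1
        with open(path, "wb") as f:
            f.write(data)
        return data

with tempfile.TemporaryDirectory() as d:
    cache = WheelCache(d)
    fake_download = lambda url: b"wheel-bytes-for-" + url.encode()
    url = "https://example.invalid/torch_xla-2.6.0.dev-cp310-none-any.whl"
    first = cache.fetch(url, fake_download)
    second = cache.fetch(url, fake_download)  # served from cache
    print(first == second, cache.downloads)   # True 1
```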
@@ -119,28 +119,20 @@ Uninstall the existing `torch` and `torch_xla` packages:
 
    pip uninstall torch torch-xla -y
 
-Install `torch` and `torch_xla`
+Install build dependencies:
 
-.. code-block:: bash
-
-   pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-   pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
-
-Install JAX and Pallas:
-
-.. code-block:: bash
-
-   pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-   pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
-Install other build dependencies:
-
 .. code-block:: bash
 
    pip install -r requirements-tpu.txt
-   VLLM_TARGET_DEVICE="tpu" python setup.py develop
    sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
+
+Run the setup script:
+
+.. code-block:: bash
+
+   VLLM_TARGET_DEVICE="tpu" python setup.py develop
 
 
 Provision Cloud TPUs with GKE
 -----------------------------
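The docs change above goes with moving the nightly pins into requirements-tpu.txt, where a file-level `--pre` flag is needed: versions carrying a `.devN` segment, such as `torch==2.6.0.dev20241028+cpu`, are PEP 440 developmental pre-releases, which pip skips by default unless `--pre` is passed or the specifier itself names a pre-release. A minimal stdlib sketch of that classification (a proper check would use `packaging.version.Version.is_prerelease`; `packaging` is itself pinned by the new requirements file):

```python
import re

# PEP 440 pre-release markers: aN / bN / rcN / .devN in the public
# version. The local segment after "+" (e.g. "+cpu") is ignored.
PRERELEASE = re.compile(r"(a|b|rc|\.dev)\d*", re.IGNORECASE)

def is_prerelease(version: str) -> bool:
    public = version.split("+", 1)[0]   # drop the local segment
    return PRERELEASE.search(public) is not None

print(is_prerelease("2.6.0.dev20241028+cpu"))  # True
print(is_prerelease("0.4.32.dev20240829"))     # True
print(is_prerelease("2.5.1"))                  # False
```

This is why the old docs invoked `pip install --pre torch==…` by hand, while the refactored requirements file states `--pre` once for every pin in it.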
@@ -168,45 +160,6 @@ Run the Docker image with the following command:
    $ # Make sure to add `--privileged --net host --shm-size=16G`.
    $ docker run --privileged --net host --shm-size=16G -it vllm-tpu
 
-
-.. _build_from_source_tpu:
-
-Build from source
------------------
-
-You can also build and install the TPU backend from source.
-
-First, install the dependencies:
-
-.. code-block:: console
-
-   $ # (Recommended) Create a new conda environment.
-   $ conda create -n myenv python=3.10 -y
-   $ conda activate myenv
-
-   $ # Clean up the existing torch and torch-xla packages.
-   $ pip uninstall torch torch-xla -y
-
-   $ # Install PyTorch and PyTorch XLA.
-   $ export DATE="20241017"
-   $ export TORCH_VERSION="2.6.0"
-   $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
-   $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
-
-   $ # Install JAX and Pallas.
-   $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
-   $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
-   $ # Install other build dependencies.
-   $ pip install -r requirements-tpu.txt
-
-
-Next, build vLLM from source. This will only take a few seconds:
-
-.. code-block:: console
-
-   $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
-
 .. note::
 
    Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
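The note retained at the end of this hunk says vLLM bucketizes input shapes so that XLA, which requires static shapes, compiles only a bounded number of graphs. A minimal sketch of the bucketing idea — not vLLM's actual policy; the round-up-to-power-of-two rule here is a hypothetical example chosen for simplicity:

```python
def bucket(n: int) -> int:
    """Round n up to the next power of two (n >= 1), so many runtime
    shapes map to a small fixed set of compiled shapes."""
    return 1 << (n - 1).bit_length()

# 128 distinct runtime batch sizes collapse into 8 compiled buckets,
# at the cost of padding inputs up to the bucket boundary.
compiled = {bucket(n) for n in range(1, 129)}
print(sorted(compiled))  # [1, 2, 4, 8, 16, 32, 64, 128]
```

The trade-off is padding overhead (a batch of 5 runs as 8) in exchange for far fewer XLA compilations.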
@@ -2,6 +2,22 @@
 -r requirements-common.txt
 
 # Dependencies for TPU
-# Currently, the TPU backend uses a nightly version of PyTorch XLA.
-# You can install the dependencies in Dockerfile.tpu.
+cmake>=3.26
+ninja
+packaging
+setuptools-scm>=8
+wheel
+jinja2
 ray[default]
+
+# Install torch_xla
+--pre
+--extra-index-url https://download.pytorch.org/whl/nightly/cpu
+--find-links https://storage.googleapis.com/libtpu-releases/index.html
+--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+torch==2.6.0.dev20241028+cpu
+torchvision==0.20.0.dev20241028+cpu
+torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
+jaxlib==0.4.32.dev20240829
+jax==0.4.32.dev20240829
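The `torch_xla[tpu] @ …` line added above pins one concrete wheel, `torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl`. Under the wheel filename convention (PEP 427) the trailing tags encode the environments pip will accept the file in, so this pin only resolves on CPython 3.10 on Linux x86_64. A small stdlib-only parser for the standard filename layout (the helper name is my own):

```python
import re

# PEP 427 wheel filename:
#   {name}-{version}(-{build})?-{python tag}-{abi tag}-{platform tag}.whl
# Name and version contain no hyphens (they are normalized to
# underscores), which makes the layout unambiguous to split.
WHEEL_RE = re.compile(
    r"^(?P<name>[^-]+)-(?P<version>[^-]+)"
    r"(?:-(?P<build>\d[^-]*))?"
    r"-(?P<python>[^-]+)-(?P<abi>[^-]+)-(?P<platform>.+)\.whl$"
)

def parse_wheel(filename: str) -> dict:
    m = WHEEL_RE.match(filename)
    if m is None:
        raise ValueError(f"not a wheel filename: {filename}")
    return m.groupdict()

tags = parse_wheel("torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl")
print(tags["python"], tags["abi"], tags["platform"])  # cp310 cp310 linux_x86_64
```

This is the practical consequence of pinning a direct URL instead of a version range: the requirements file now assumes the Python 3.10 environment that Dockerfile.tpu provides.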