From 348897af3112c4e9f6fdc3cd1a7093c98b10e705 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Oct 2023 11:27:17 -0700 Subject: [PATCH] Fix PyTorch version to 2.0.1 in workflow (#1377) --- .github/workflows/publish.yml | 5 +++-- .github/workflows/scripts/pytorch-install.sh | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 770eded5..e4210734 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,6 +49,7 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] + pytorch-version: ['2.0.1'] cuda-version: ['11.8'] # Github runner can't build anything older than 11.8 steps: @@ -69,9 +70,9 @@ jobs: run: | bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - name: Install PyTorch-cu${{ matrix.cuda-version }} + - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} run: | - bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} + bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - name: Build wheel shell: bash diff --git a/.github/workflows/scripts/pytorch-install.sh b/.github/workflows/scripts/pytorch-install.sh index 3e20d9a8..610f08fd 100644 --- a/.github/workflows/scripts/pytorch-install.sh +++ b/.github/workflows/scripts/pytorch-install.sh @@ -1,11 +1,12 @@ #!/bin/bash python_executable=python$1 -cuda_version=$2 +pytorch_version=$2 +cuda_version=$3 # Install torch $python_executable -m pip install numpy pyyaml scipy ipython mkl mkl-include ninja cython typing pandas typing-extensions dataclasses setuptools && conda clean -ya -$python_executable -m pip install torch -f https://download.pytorch.org/whl/cu${cuda_version//./}/torch_stable.html +$python_executable -m pip install torch==${pytorch_version}+cu${cuda_version//./} --index-url https://download.pytorch.org/whl/cu${cuda_version//./} # Print version information $python_executable --version