diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d33fba9..532d9ff 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -133,9 +133,13 @@ jobs: print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - # --no-deps because we can't install old versions of pytorch-triton - pip install typing-extensions jinja2 - pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then + # --no-deps because we can't install old versions of pytorch-triton + pip install typing-extensions jinja2 + pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + else + pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + fi else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 6dca37f..729cb16 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.5.1" +__version__ = "2.5.1.post1" from flash_attn.flash_attn_interface import ( flash_attn_func,