diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a9bd229..a0244f8 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -126,7 +126,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: upload_url: ${{ steps.get_current_release.outputs.upload_url }} - asset_path: ./${{env.wheel_name}} + asset_path: ./dist/${{env.wheel_name}} asset_name: ${{env.wheel_name}} asset_content_type: application/* diff --git a/setup.py b/setup.py index a5b63b1..91a37ce 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,8 @@ from packaging.version import parse, Version from setuptools import setup, find_packages import subprocess -import urllib +import urllib.request +import urllib.error import torch from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME @@ -43,8 +44,10 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down class CustomInstallCommand(install): def run(self): + raise_if_cuda_home_none("flash_attn") + # Determine the version numbers that will be used to determine the correct wheel - _, cuda_version = get_cuda_bare_metal_version() + _, cuda_version = get_cuda_bare_metal_version(CUDA_HOME) torch_version = torch.__version__ python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() @@ -64,7 +67,10 @@ class CustomInstallCommand(install): except urllib.error.HTTPError: print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source - install.run(self) + #install.run(self) + raise ValueError + + raise ValueError def get_cuda_bare_metal_version(cuda_dir):