diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 4f62194..dad5d7d 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -109,6 +109,7 @@ jobs: - name: Build wheel run: | + export FLASH_ATTENTION_FORCE_BUILD="TRUE" export FORCE_CUDA="1" export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH diff --git a/setup.py b/setup.py index 7581d74..e0fcddd 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,9 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down class CustomInstallCommand(install): def run(self): + if os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE": + return install.run(self) + raise_if_cuda_home_none("flash_attn") # Determine the version numbers that will be used to determine the correct wheel @@ -59,7 +62,7 @@ class CustomInstallCommand(install): wheel_url = BASE_WHEEL_URL.format( #tag_name=f"v{flash_version}", # HACK - tag_name=f"v0.0.3", + tag_name=f"v0.0.5", wheel_name=wheel_filename ) print("Guessing wheel URL: ", wheel_url)