Allow fallback install
This commit is contained in:
parent
dab99053e4
commit
5e4699782a
1
.github/workflows/publish.yml
vendored
1
.github/workflows/publish.yml
vendored
@ -109,6 +109,7 @@ jobs:
|
||||
|
||||
- name: Build wheel
|
||||
run: |
|
||||
export FLASH_ATTENTION_FORCE_BUILD="TRUE"
|
||||
export FORCE_CUDA="1"
|
||||
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
|
||||
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
5
setup.py
5
setup.py
@ -44,6 +44,9 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down
|
||||
|
||||
class CustomInstallCommand(install):
|
||||
def run(self):
|
||||
if os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE":
|
||||
return install.run(self)
|
||||
|
||||
raise_if_cuda_home_none("flash_attn")
|
||||
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
@ -59,7 +62,7 @@ class CustomInstallCommand(install):
|
||||
wheel_url = BASE_WHEEL_URL.format(
|
||||
#tag_name=f"v{flash_version}",
|
||||
# HACK
|
||||
tag_name=f"v0.0.3",
|
||||
tag_name=f"v0.0.5",
|
||||
wheel_name=wheel_filename
|
||||
)
|
||||
print("Guessing wheel URL: ", wheel_url)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user