From 0e7769c813fcd2b04882a9cd7e13945002a903d3 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Fri, 2 Jun 2023 14:41:07 -0700 Subject: [PATCH] Guessing wheel URL --- setup.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 91a37ce..7581d74 100644 --- a/setup.py +++ b/setup.py @@ -47,18 +47,22 @@ class CustomInstallCommand(install): raise_if_cuda_home_none("flash_attn") # Determine the version numbers that will be used to determine the correct wheel - _, cuda_version = get_cuda_bare_metal_version(CUDA_HOME) + _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_version = torch.__version__ python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() + cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" # Determine wheel URL based on CUDA version, torch version, python version and OS wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl' wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{flash_version}", + #tag_name=f"v{flash_version}", + # HACK + tag_name=f"v0.0.3", wheel_name=wheel_filename ) + print("Guessing wheel URL: ", wheel_url) try: urllib.request.urlretrieve(wheel_url, wheel_filename) @@ -70,8 +74,6 @@ class CustomInstallCommand(install): #install.run(self) raise ValueError - raise ValueError - def get_cuda_bare_metal_version(cuda_dir): raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)