Guessing wheel URL
This commit is contained in:
parent
e1faefce9d
commit
0e7769c813
10
setup.py
10
setup.py
@ -47,18 +47,22 @@ class CustomInstallCommand(install):
|
||||
raise_if_cuda_home_none("flash_attn")
|
||||
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
_, cuda_version = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
_, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
torch_version = torch.__version__
|
||||
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
||||
platform_name = get_platform()
|
||||
flash_version = get_package_version()
|
||||
cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
|
||||
|
||||
# Determine wheel URL based on CUDA version, torch version, python version and OS
|
||||
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'
|
||||
wheel_url = BASE_WHEEL_URL.format(
|
||||
tag_name=f"v{flash_version}",
|
||||
#tag_name=f"v{flash_version}",
|
||||
# HACK
|
||||
tag_name=f"v0.0.3",
|
||||
wheel_name=wheel_filename
|
||||
)
|
||||
print("Guessing wheel URL: ", wheel_url)
|
||||
|
||||
try:
|
||||
urllib.request.urlretrieve(wheel_url, wheel_filename)
|
||||
@ -70,8 +74,6 @@ class CustomInstallCommand(install):
|
||||
#install.run(self)
|
||||
raise ValueError
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user