Strip cuda name from torch version

This commit is contained in:
Pierce Freeman 2023-06-02 18:02:24 -07:00
parent 5e4699782a
commit 9fc9820a5b

View File

@ -51,11 +51,12 @@ class CustomInstallCommand(install):
# Determine the version numbers that will be used to determine the correct wheel
_, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_version = torch.__version__
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
flash_version = get_package_version()
cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'