From 9fc9820a5bf0eb851b79388908f43a70affbe296 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Fri, 2 Jun 2023 18:02:24 -0700 Subject: [PATCH] Strip cuda name from torch version --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e0fcddd..ff718ef 100644 --- a/setup.py +++ b/setup.py @@ -51,11 +51,12 @@ class CustomInstallCommand(install): # Determine the version numbers that will be used to determine the correct wheel _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - torch_version = torch.__version__ + torch_version_raw = parse(torch.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}.{torch_version_raw.micro}" # Determine wheel URL based on CUDA version, torch version, python version and OS wheel_filename = f'flash_attn-{flash_version}+cu{cuda_version}torch{torch_version}-{python_version}-{python_version}-{platform_name}.whl'