From b0e09d7cd371eded41f7c1e057caf1593c27ba55 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Mon, 18 Nov 2024 15:06:32 +0100 Subject: [PATCH] Fix `cutlass` python library with cuda `12.6.2.post1` (#1942) * Fix `cutlass` python library with cuda `12.6.2.post1` Previously we had this error: ``` File "/storage/home/cutlass/python/cutlass/backend/operation.py", line 39, in _version_splits = [int(x) for x in __version__.split("rc")[0].split(".")] ^^^^^^ ValueError: invalid literal for int() with base 10: 'post1' ``` * Update sm90_utils.py * Update generator.py * Update python/cutlass_library/generator.py Co-authored-by: Jack Kosaian * Update python/cutlass_library/sm90_utils.py Co-authored-by: Jack Kosaian --------- Co-authored-by: Jack Kosaian --- python/cutlass/backend/operation.py | 2 +- python/cutlass_library/generator.py | 2 +- python/cutlass_library/sm90_utils.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py index 568c1f69..a73cef68 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass/backend/operation.py @@ -36,7 +36,7 @@ from cuda import __version__, cuda from cutlass.backend.utils.device import device_cc -_version_splits = [int(x) for x in __version__.split("rc")[0].split(".")] +_version_splits = [int(x) for x in __version__.split("rc")[0].split(".post")[0].split(".")] _supports_cluster_launch = None diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 85fdbb8e..e6a9f9e8 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -103,7 +103,7 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): # Update cuda_version based on parsed string if semantic_ver_string != '': - for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]): + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): if i < len(cuda_version): cuda_version[i] = x else: diff --git a/python/cutlass_library/sm90_utils.py b/python/cutlass_library/sm90_utils.py index f4abd94c..08fcd547 100644 --- a/python/cutlass_library/sm90_utils.py +++ b/python/cutlass_library/sm90_utils.py @@ -61,7 +61,7 @@ def CudaToolkitVersionSatisfies(semantic_ver_string, major, minor, patch = 0): # Update cuda_version based on parsed string if semantic_ver_string != '': - for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')]): + for i, x in enumerate([int(x) for x in semantic_ver_string.split('.')[:3]]): if i < len(cuda_version): cuda_version[i] = x else: