diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index e69de29..68cdeee 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -0,0 +1 @@ +__version__ = "1.0.5" diff --git a/setup.py b/setup.py index 4677bb0..7597ea3 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,8 @@ import sys import warnings import os +import re +import ast from pathlib import Path from packaging.version import parse, Version @@ -160,9 +162,19 @@ ext_modules.append( ) ) +def get_package_version(): + with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + setup( name="flash_attn", - version="1.0.5", + version=get_package_version(), packages=find_packages( exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) ),