Raise cuda error on build
This commit is contained in:
parent
add4f0bc42
commit
e1faefce9d
2
.github/workflows/publish.yml
vendored
2
.github/workflows/publish.yml
vendored
@ -126,7 +126,7 @@ jobs:
|
|||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
with:
|
with:
|
||||||
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
||||||
asset_path: ./${{env.wheel_name}}
|
asset_path: ./dist/${{env.wheel_name}}
|
||||||
asset_name: ${{env.wheel_name}}
|
asset_name: ${{env.wheel_name}}
|
||||||
asset_content_type: application/*
|
asset_content_type: application/*
|
||||||
|
|
||||||
|
|||||||
12
setup.py
12
setup.py
@ -10,7 +10,8 @@ from packaging.version import parse, Version
|
|||||||
from setuptools import setup, find_packages
|
from setuptools import setup, find_packages
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
import urllib
|
import urllib.request
|
||||||
|
import urllib.error
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||||
|
|
||||||
@ -43,8 +44,10 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down
|
|||||||
|
|
||||||
class CustomInstallCommand(install):
|
class CustomInstallCommand(install):
|
||||||
def run(self):
|
def run(self):
|
||||||
|
raise_if_cuda_home_none("flash_attn")
|
||||||
|
|
||||||
# Determine the version numbers that will be used to determine the correct wheel
|
# Determine the version numbers that will be used to determine the correct wheel
|
||||||
_, cuda_version = get_cuda_bare_metal_version()
|
_, cuda_version = get_cuda_bare_metal_version(CUDA_HOME)
|
||||||
torch_version = torch.__version__
|
torch_version = torch.__version__
|
||||||
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
||||||
platform_name = get_platform()
|
platform_name = get_platform()
|
||||||
@ -64,7 +67,10 @@ class CustomInstallCommand(install):
|
|||||||
except urllib.error.HTTPError:
|
except urllib.error.HTTPError:
|
||||||
print("Precompiled wheel not found. Building from source...")
|
print("Precompiled wheel not found. Building from source...")
|
||||||
# If the wheel could not be downloaded, build from source
|
# If the wheel could not be downloaded, build from source
|
||||||
install.run(self)
|
#install.run(self)
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_bare_metal_version(cuda_dir):
|
def get_cuda_bare_metal_version(cuda_dir):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user