Raise cuda error on build
This commit is contained in:
parent
add4f0bc42
commit
e1faefce9d
2
.github/workflows/publish.yml
vendored
2
.github/workflows/publish.yml
vendored
@ -126,7 +126,7 @@ jobs:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
with:
|
||||
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
|
||||
asset_path: ./${{env.wheel_name}}
|
||||
asset_path: ./dist/${{env.wheel_name}}
|
||||
asset_name: ${{env.wheel_name}}
|
||||
asset_content_type: application/*
|
||||
|
||||
|
||||
12
setup.py
12
setup.py
@ -10,7 +10,8 @@ from packaging.version import parse, Version
|
||||
from setuptools import setup, find_packages
|
||||
import subprocess
|
||||
|
||||
import urllib
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
|
||||
|
||||
@ -43,8 +44,10 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down
|
||||
|
||||
class CustomInstallCommand(install):
|
||||
def run(self):
|
||||
raise_if_cuda_home_none("flash_attn")
|
||||
|
||||
# Determine the version numbers that will be used to determine the correct wheel
|
||||
_, cuda_version = get_cuda_bare_metal_version()
|
||||
_, cuda_version = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
torch_version = torch.__version__
|
||||
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
|
||||
platform_name = get_platform()
|
||||
@ -64,7 +67,10 @@ class CustomInstallCommand(install):
|
||||
except urllib.error.HTTPError:
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
install.run(self)
|
||||
#install.run(self)
|
||||
raise ValueError
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
def get_cuda_bare_metal_version(cuda_dir):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user