Raise cuda error on build

This commit is contained in:
Pierce Freeman 2023-06-02 13:20:39 -07:00
parent add4f0bc42
commit e1faefce9d
2 changed files with 10 additions and 4 deletions

View File

@ -126,7 +126,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with: with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }} upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./${{env.wheel_name}} asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}} asset_name: ${{env.wheel_name}}
asset_content_type: application/* asset_content_type: application/*

View File

@ -10,7 +10,8 @@ from packaging.version import parse, Version
from setuptools import setup, find_packages from setuptools import setup, find_packages
import subprocess import subprocess
import urllib import urllib.request
import urllib.error
import torch import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
@ -43,8 +44,10 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down
class CustomInstallCommand(install): class CustomInstallCommand(install):
def run(self): def run(self):
raise_if_cuda_home_none("flash_attn")
# Determine the version numbers that will be used to determine the correct wheel # Determine the version numbers that will be used to determine the correct wheel
_, cuda_version = get_cuda_bare_metal_version() _, cuda_version = get_cuda_bare_metal_version(CUDA_HOME)
torch_version = torch.__version__ torch_version = torch.__version__
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform() platform_name = get_platform()
@ -64,7 +67,10 @@ class CustomInstallCommand(install):
except urllib.error.HTTPError: except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...") print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source # If the wheel could not be downloaded, build from source
install.run(self) #install.run(self)
raise ValueError
raise ValueError
def get_cuda_bare_metal_version(cuda_dir): def get_cuda_bare_metal_version(cuda_dir):