Raise cuda error on build

This commit is contained in:
Pierce Freeman 2023-06-02 13:20:39 -07:00
parent add4f0bc42
commit e1faefce9d
2 changed files with 10 additions and 4 deletions

View File

@ -126,7 +126,7 @@ jobs:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./${{env.wheel_name}}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*

View File

@ -10,7 +10,8 @@ from packaging.version import parse, Version
from setuptools import setup, find_packages
import subprocess
import urllib
import urllib.request
import urllib.error
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
@ -43,8 +44,10 @@ BASE_WHEEL_URL = "https://github.com/piercefreeman/flash-attention/releases/down
class CustomInstallCommand(install):
def run(self):
raise_if_cuda_home_none("flash_attn")
# Determine the version numbers that will be used to determine the correct wheel
_, cuda_version = get_cuda_bare_metal_version()
_, cuda_version = get_cuda_bare_metal_version(CUDA_HOME)
torch_version = torch.__version__
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
@ -64,7 +67,10 @@ class CustomInstallCommand(install):
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
install.run(self)
#install.run(self)
raise ValueError
raise ValueError
def get_cuda_bare_metal_version(cuda_dir):