2022-10-24 14:04:16 +08:00
|
|
|
# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron
|
|
|
|
|
# We add the case where seqlen = 4k and seqlen = 8k
|
|
|
|
|
import os
|
|
|
|
|
import subprocess
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from setuptools import setup
|
|
|
|
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cuda_bare_metal_version(cuda_dir):
    """Query the CUDA toolkit under ``cuda_dir`` for its release version.

    Runs ``<cuda_dir>/bin/nvcc -V`` and parses the token that follows the
    word ``release`` in its output (for example ``release 11.2, V11.2.152``).

    Args:
        cuda_dir: Root directory of the CUDA installation, as a string.

    Returns:
        A ``(raw_output, major, minor)`` tuple. ``raw_output`` is the full
        text printed by ``nvcc -V``; ``major`` and ``minor`` are the version
        components as strings.

    Raises:
        subprocess.CalledProcessError: If ``nvcc`` exits with a non-zero
            status.
        ValueError: If the output does not contain the word ``release``.
    """
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    tokens = raw_output.split()

    # The version token is followed by a comma ("11.2,"); drop it so the
    # minor component keeps all of its digits instead of only the first one.
    release = tokens[tokens.index("release") + 1].rstrip(",").split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1]

    return raw_output, bare_metal_major, bare_metal_minor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def append_nvcc_threads(nvcc_extra_args):
    """Add nvcc's ``--threads`` option when the installed CUDA is new enough.

    The toolkit version is compared as a ``(major, minor)`` pair against
    ``(11, 2)``. Comparing the two components independently would wrongly
    reject releases such as 12.0 and 12.1, whose minor number is below 2.

    Args:
        nvcc_extra_args: List of extra command-line arguments for nvcc.

    Returns:
        A new list with ``--threads <n>`` appended when the CUDA toolkit at
        ``CUDA_HOME`` reports version 11.2 or later; otherwise
        ``nvcc_extra_args`` itself, unchanged. ``<n>`` is the value of the
        ``NVCC_THREADS`` environment variable, or ``"4"`` when that variable
        is unset or empty.
    """
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
        return nvcc_extra_args + ["--threads", nvcc_threads]
    return nvcc_extra_args
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# nvcc code-generation flags: every target architecture contributes a
# "-gencode" switch followed by its arch/code specification (sm_70, sm_80).
cc_flag = [
    flag
    for arch in ("70", "80")
    for flag in ("-gencode", f"arch=compute_{arch},code=sm_{arch}")
]
|
|
|
|
|
|
|
|
|
|
# Build the fused softmax extension from its C++ binding file and the two
# CUDA source files, using torch's BuildExtension as the build_ext command.
setup(
    name="fused_softmax_lib",
    ext_modules=[
        CUDAExtension(
            name="fused_softmax_lib",
            sources=[
                "fused_softmax.cpp",
                "scaled_masked_softmax_cuda.cu",
                "scaled_upper_triang_masked_softmax_cuda.cu",
            ],
            extra_compile_args=dict(
                # Flags passed to the host C++ compiler.
                cxx=["-O3"],
                # Flags passed to nvcc: optimisation, fast math, the target
                # architectures from cc_flag, and --threads when applicable.
                nvcc=append_nvcc_threads(["-O3", "--use_fast_math", *cc_flag]),
            ),
        ),
    ],
    cmdclass=dict(build_ext=BuildExtension),
)
|