Add Cutlass as submodule
This commit is contained in:
parent
ad6c694bb3
commit
512c98ee05
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
[submodule "csrc/flash_attn/cutlass"]
|
||||||
|
path = csrc/flash_attn/cutlass
|
||||||
|
url = https://github.com/NVIDIA/cutlass.git
|
||||||
1
csrc/flash_attn/cutlass
Submodule
1
csrc/flash_attn/cutlass
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 319a389f42b776fae5701afcb943fc03be5b5c25
|
||||||
2
setup.py
2
setup.py
@ -111,6 +111,7 @@ if int(bare_metal_major) < 11:
|
|||||||
cc_flag.append("-gencode")
|
cc_flag.append("-gencode")
|
||||||
cc_flag.append("arch=compute_80,code=sm_80")
|
cc_flag.append("arch=compute_80,code=sm_80")
|
||||||
|
|
||||||
|
subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
|
||||||
ext_modules.append(
|
ext_modules.append(
|
||||||
CUDAExtension(
|
CUDAExtension(
|
||||||
name="flash_attn_cuda",
|
name="flash_attn_cuda",
|
||||||
@ -141,6 +142,7 @@ ext_modules.append(
|
|||||||
include_dirs=[
|
include_dirs=[
|
||||||
Path(this_dir) / 'csrc' / 'flash_attn',
|
Path(this_dir) / 'csrc' / 'flash_attn',
|
||||||
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
|
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
|
||||||
|
Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user