Add Cutlass as submodule
This commit is contained in:
parent
ad6c694bb3
commit
512c98ee05
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
[submodule "csrc/flash_attn/cutlass"]
|
||||
path = csrc/flash_attn/cutlass
|
||||
url = https://github.com/NVIDIA/cutlass.git
|
||||
1
csrc/flash_attn/cutlass
Submodule
1
csrc/flash_attn/cutlass
Submodule
@ -0,0 +1 @@
|
||||
Subproject commit 319a389f42b776fae5701afcb943fc03be5b5c25
|
||||
2
setup.py
2
setup.py
@ -111,6 +111,7 @@ if int(bare_metal_major) < 11:
|
||||
cc_flag.append("-gencode")
|
||||
cc_flag.append("arch=compute_80,code=sm_80")
|
||||
|
||||
subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
|
||||
ext_modules.append(
|
||||
CUDAExtension(
|
||||
name="flash_attn_cuda",
|
||||
@ -141,6 +142,7 @@ ext_modules.append(
|
||||
include_dirs=[
|
||||
Path(this_dir) / 'csrc' / 'flash_attn',
|
||||
Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
|
||||
Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user