diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..a8e8349
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "csrc/flash_attn/cutlass"]
+	path = csrc/flash_attn/cutlass
+	url = https://github.com/NVIDIA/cutlass.git
diff --git a/csrc/flash_attn/cutlass b/csrc/flash_attn/cutlass
new file mode 160000
index 0000000..319a389
--- /dev/null
+++ b/csrc/flash_attn/cutlass
@@ -0,0 +1 @@
+Subproject commit 319a389f42b776fae5701afcb943fc03be5b5c25
diff --git a/setup.py b/setup.py
index eff5a46..a661fa6 100644
--- a/setup.py
+++ b/setup.py
@@ -111,6 +111,14 @@ if int(bare_metal_major) < 11:
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
 
+# Fetch the CUTLASS submodule when building from a git checkout. Outside a
+# checkout there is no submodule to update, so the call is skipped.
+if (Path(this_dir) / ".git").exists():
+    subprocess.run(
+        ["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"],
+        cwd=this_dir,  # the submodule path is relative to the repository root
+        check=True,  # stop here on failure instead of building without headers
+    )
 ext_modules.append(
     CUDAExtension(
         name="flash_attn_cuda",
@@ -141,6 +149,7 @@ ext_modules.append(
         include_dirs=[
             Path(this_dir) / 'csrc' / 'flash_attn',
             Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
+            Path(this_dir) / 'csrc' / 'flash_attn' / 'cutlass' / 'include',
         ],
     )
 )