flash-attention/hopper
Cameron Shinn 3cea2fb6ee
Add ArchTag to pre/postprocess bwd kernels (#1180)
* Add ArchTag to pre/postprocess bwd kernels

* Type-dependent CC check for bwd pre/postprocess

* Fix CC >= 90 for bwd postprocess

---------

Co-authored-by: Cameron Shinn <cshinn@nvidia.com>
2024-08-28 00:20:47 -07:00
__init__.py FA3 initial code release 2024-07-11 09:53:36 -07:00
benchmark_attn.py bwd benchmark + small fixes (#1129) 2024-08-05 21:27:52 -07:00
benchmark_flash_attention_fp8.py FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
block_info.h FA3 initial code release 2024-07-11 09:53:36 -07:00
epilogue_bwd_sm90_tma.hpp [FA3] Bwd 2024-08-01 01:57:06 -07:00
epilogue_fwd_sm90_tma.hpp Fix out-of-bound writes for var-seq-len zero-length KVs 2024-08-16 01:17:40 -07:00
flash_api.cpp FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
flash_attn_interface.py FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
flash_bwd_hdim64_bf16_sm90.cu [FA3] Bwd 2024-08-01 01:57:06 -07:00
flash_bwd_hdim64_fp16_sm90.cu FA3 initial code release 2024-07-11 09:53:36 -07:00
flash_bwd_hdim96_bf16_sm90.cu [FA3] Bwd 2024-08-01 01:57:06 -07:00
flash_bwd_hdim96_fp16_sm90.cu [FA3] Bwd 2024-08-01 01:57:06 -07:00
flash_bwd_hdim128_bf16_sm90.cu [FA3] Bwd 2024-08-01 01:57:06 -07:00
flash_bwd_hdim128_fp16_sm90.cu FA3 initial code release 2024-07-11 09:53:36 -07:00
flash_bwd_hdim256_fp16_sm90.cu FA3 initial code release 2024-07-11 09:53:36 -07:00
flash_bwd_kernel.h Remove struct : cute::aligned_struct to avoid error with gcc 12 2024-08-02 00:59:35 -07:00
flash_bwd_launch_template.h Add ArchTag to pre/postprocess bwd kernels (#1180) 2024-08-28 00:20:47 -07:00
flash_bwd_postprocess_kernel.h Add ArchTag to pre/postprocess bwd kernels (#1180) 2024-08-28 00:20:47 -07:00
flash_bwd_preprocess_kernel.h Add ArchTag to pre/postprocess bwd kernels (#1180) 2024-08-28 00:20:47 -07:00
flash_fwd_hdim64_bf16_sm90.cu [FA3] BF16 forward 2024-07-14 23:39:46 -07:00
flash_fwd_hdim64_e4m3_sm90.cu Fp8 kernel with "in-kernel" transpose of V in producer (#1100) 2024-07-30 14:14:14 -07:00
flash_fwd_hdim64_fp16_sm90.cu FA3 initial code release 2024-07-11 09:53:36 -07:00
flash_fwd_hdim128_bf16_sm90.cu [FA3] BF16 forward 2024-07-14 23:39:46 -07:00
flash_fwd_hdim128_e4m3_sm90.cu Fp8 kernel with "in-kernel" transpose of V in producer (#1100) 2024-07-30 14:14:14 -07:00
flash_fwd_hdim128_fp16_sm90.cu FA3 initial code release 2024-07-11 09:53:36 -07:00
flash_fwd_hdim256_bf16_sm90.cu [FA3] BF16 forward 2024-07-14 23:39:46 -07:00
flash_fwd_hdim256_e4m3_sm90.cu Fp8 kernel with "in-kernel" transpose of V in producer (#1100) 2024-07-30 14:14:14 -07:00
flash_fwd_hdim256_fp16_sm90.cu FA3 initial code release 2024-07-11 09:53:36 -07:00
flash_fwd_kernel.h FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
flash_fwd_launch_template.h FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
flash.h FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
kernel_traits.h FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
mainloop_bwd_sm90_tma_gmma_ws.hpp [FA3] Bwd 2024-08-01 01:57:06 -07:00
mainloop_fwd_sm90_tma_gmma_ws.hpp FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
named_barrier.hpp [FA3] Bwd 2024-08-01 01:57:06 -07:00
seq_len.h Add var-seq-len to FA3 fp16 / bf16 fwd (#1072) 2024-07-22 21:32:41 -07:00
setup.py [FA3] Bwd 2024-08-01 01:57:06 -07:00
softmax.h FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
static_switch.h Add var-seq-len to FA3 fp16 / bf16 fwd (#1072) 2024-07-22 21:32:41 -07:00
test_flash_attn.py FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
tile_scheduler_bwd.hpp [FA3] Bwd 2024-08-01 01:57:06 -07:00
tile_scheduler.hpp Fp8 kernel with "in-kernel" transpose of V in producer (#1100) 2024-07-30 14:14:14 -07:00
utils.h [FA3] Bwd 2024-08-01 01:57:06 -07:00