diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index bd29d56..ab9f367 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -491,8 +491,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons using GmemTiledCopyO = std::conditional_t< !Split, - typename Kernel_traits::GmemTiledCopyOaccum, - typename Kernel_traits::GmemTiledCopyO + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum >; using ElementO = std::conditional_t;