From 3e9414f1c37ab033544b47e9bec4f7103c6192f1 Mon Sep 17 00:00:00 2001 From: ljss <31004720+beginlner@users.noreply.github.com> Date: Thu, 28 Mar 2024 10:11:45 +0800 Subject: [PATCH] Minor fix in compute_attn_1rowblock_splitkv (#900) --- csrc/flash_attn/src/flash_fwd_kernel.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index bd29d56..ab9f367 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -491,8 +491,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons using GmemTiledCopyO = std::conditional_t< !Split, - typename Kernel_traits::GmemTiledCopyOaccum, - typename Kernel_traits::GmemTiledCopyO + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum >; using ElementO = std::conditional_t;