diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 42d6322..5ba68e9 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -16,8 +16,7 @@ #include "softmax.h" #include "mask.h" #include "dropout.h" - -#include "alibi.h" +#include "rotary.h" namespace flash { @@ -222,16 +221,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Prologue - Tensor tQrQ = make_fragment_like(tQgQ); // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } - // // Copy rmem to smem - // // copy(tQrQ, tQsQ); - // flash::cp_async_wait<0>(); - // __syncthreads(); // // if (cute::thread(1, 0)) { print(tQsQ); } // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); // // if (cute::thread0()) { print(sQNoSwizzle); } @@ -744,7 +738,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons } // Read Q from gmem to smem, optionally apply rotary embedding. - Tensor tQrQ = make_fragment_like(tQgQ); if (!Append_KV || params.rotary_dim == 0) { // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, diff --git a/csrc/flash_attn/src/rotary.h b/csrc/flash_attn/src/rotary.h new file mode 100644 index 0000000..dc2825b --- /dev/null +++ b/csrc/flash_attn/src/rotary.h @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include "utils.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, + Tensor &D, + Tensor const &Cos, + Tensor const &Sin, + Tensor const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type(rS(_, m, k)); + Tensor S_other_fp32 = convert_type(rS_other); + Tensor cos_fp32 = convert_type(rCos(_, m, k)); + Tensor sin_fp32 = convert_type(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 4d644ea..2b45e87 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -391,137 +391,4 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor const &S //////////////////////////////////////////////////////////////////////////////////////////////////// -template -__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace flash