Move rotary device functions to a separate file
This commit is contained in:
parent
3e2c827d9a
commit
395e5a0dba
@ -16,8 +16,7 @@
|
||||
#include "softmax.h"
|
||||
#include "mask.h"
|
||||
#include "dropout.h"
|
||||
|
||||
#include "alibi.h"
|
||||
#include "rotary.h"
|
||||
|
||||
namespace flash {
|
||||
|
||||
@ -222,16 +221,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// Prologue
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
// // Copy rmem to smem
|
||||
// // copy(tQrQ, tQsQ);
|
||||
// flash::cp_async_wait<0>();
|
||||
// __syncthreads();
|
||||
// // if (cute::thread(1, 0)) { print(tQsQ); }
|
||||
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
|
||||
// // if (cute::thread0()) { print(sQNoSwizzle); }
|
||||
@ -744,7 +738,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
}
|
||||
|
||||
// Read Q from gmem to smem, optionally apply rotary embedding.
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
if (!Append_KV || params.rotary_dim == 0) {
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
|
||||
152
csrc/flash_attn/src/rotary.h
Normal file
152
csrc/flash_attn/src/rotary.h
Normal file
@ -0,0 +1,152 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
|
||||
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
cute::copy(Cos(_, m, k), rCos(_, m, k));
|
||||
cute::copy(Sin(_, m, k), rSin(_, m, k));
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS) / 2; ++i) {
|
||||
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
|
||||
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
|
||||
S_fp32(2 * i) = real;
|
||||
S_fp32(2 * i + 1) = imag;
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
|
||||
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
|
||||
cute::copy(gS_other, rS_other);
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
|
||||
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
|
||||
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
|
||||
cute::copy(gCos, rCos(_, m, k));
|
||||
cute::copy(gSin, rSin(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor S_other_fp32 = convert_type<float>(rS_other);
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS); ++i) {
|
||||
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
@ -391,137 +391,4 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
|
||||
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
cute::copy(Cos(_, m, k), rCos(_, m, k));
|
||||
cute::copy(Sin(_, m, k), rSin(_, m, k));
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS) / 2; ++i) {
|
||||
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
|
||||
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
|
||||
S_fp32(2 * i) = real;
|
||||
S_fp32(2 * i + 1) = imag;
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
|
||||
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
|
||||
cute::copy(gS_other, rS_other);
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
|
||||
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
|
||||
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
|
||||
cute::copy(gCos, rCos(_, m, k));
|
||||
cute::copy(gSin, rSin(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor S_other_fp32 = convert_type<float>(rS_other);
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS); ++i) {
|
||||
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
|
||||
Loading…
Reference in New Issue
Block a user