[Gen] Add rotary base as an argument to FT attention kernel
parent 7c766b1bbc
commit 48bc6eacd6
@@ -84,6 +84,7 @@ struct Multihead_attention_params_base {
     // The per-head latent space reserved for rotary embeddings.
     int rotary_embedding_dim = 0;
     bool neox_rotary_style = false;
+    float rotary_base = 0.0f;
     // The maximum length of input sentences.
     int max_input_length = 0;
     // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?

@@ -1061,10 +1061,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
         else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {

@@ -1099,13 +1099,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);

             mmha::apply_rotary_embedding(
-                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
+                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);

             mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
         }
         else {
             mmha::apply_rotary_embedding(
-                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
+                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
         }
         mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
     }

@@ -1272,9 +1272,9 @@ inline __device__ void zero(T& dst)

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
 {
-    const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
+    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
     return {cos(inv_freq), sin(inv_freq)};
 }
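For reference, a minimal Python sketch (an illustration, not part of the commit) of what the modified rotary_embedding_coefficient computes: the rotation angle for channel pair zid at timestep t_step is t_step / base^(zid / rot_embed_dim), so a larger base gives lower per-channel frequencies, i.e. longer positional wavelengths.

import math

def rotary_embedding_coefficient(zid, rot_embed_dim, t_step, base=10000.0):
    # Mirrors the device function above; `base` used to be hard-coded
    # to 10000.0f in the CUDA version.
    inv_freq = t_step / (base ** (zid / rot_embed_dim))
    return math.cos(inv_freq), math.sin(inv_freq)

# A larger base yields a smaller angle for the same position:
print(rotary_embedding_coefficient(2, 64, 100))                # default base
print(rotary_embedding_coefficient(2, 64, 100, base=24000.0))  # custom base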
@@ -1302,49 +1302,49 @@ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162
 }
 #endif

-inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     return;
 }

-inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     return;
 }

-inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }

-inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }

-inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }

     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
@@ -1352,166 +1352,166 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
     }

     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
     Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
     k_.x = rotary_embedding_transform(k_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
     k_.y = rotary_embedding_transform(k_.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }

-inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }

-inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }

-inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }

 #ifdef ENABLE_BF16
-inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }

 inline __device__ void
-apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }

-inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }

-inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }

-inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
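All of these vectorized overloads just forward base to rotary_embedding_coefficient and then rotate each (x, y) channel pair; the coefficient index steps by 2 per pair (2*tid, 4*tid, 8*tid ...) because each pair consumes two channels. The body of rotary_embedding_transform is not shown in this diff; assuming the standard RoPE 2-D rotation, a Python sketch of the per-pair transform:

def rotary_embedding_transform(v, coef):
    # v is one (x, y) channel pair; coef is the (cos, sin) returned by
    # rotary_embedding_coefficient. Standard 2-D rotation assumed.
    cos_t, sin_t = coef
    x, y = v
    return (x * cos_t - y * sin_t, x * sin_t + y * cos_t)

# e.g. a float4 overload rotates two pairs with angles at 4*tid and
# 4*tid + 2, which is why the coefficient stride doubles with vector width.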
@@ -54,6 +54,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
+                const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
                 T *q_ptr,

@@ -82,6 +83,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.num_heads = nheads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
+    params.rotary_base = rotary_base;
     params.neox_rotary_style = neox_rotary_style;
     params.timestep = timestep;
     params.inv_sqrt_dh = 1.f / sqrt(float(headdim));

@@ -107,6 +109,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      c10::optional<const torch::Tensor> length_per_sample_,
                                      const int timestep,
                                      const int rotary_embedding_dim = 0,
+                                     const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);

@@ -144,7 +147,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, neox_rotary_style, q.stride(0),
+                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),

@@ -163,5 +166,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
           py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
-          py::arg("neox_rotary_style")=true);
+          py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
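With the binding above, a call from Python might look like the following sketch. The keyword names come from the py::arg(...) list in this diff; the tensor contents and cache layout are whatever the extension expects and are left as placeholders here.

import ft_attention  # the compiled extension registered above

def decode_step(q, k, v, k_cache, v_cache, timestep):
    # q/k/v are the current step's projections; k_cache/v_cache hold
    # the past keys and values in the extension's expected layout.
    return ft_attention.single_query_attention(
        q, k, v, k_cache, v_cache,
        None,                   # length_per_sample_
        timestep,
        rotary_embedding_dim=64,
        rotary_base=24000.0,    # new argument; defaults to 10000.0f
        neox_rotary_style=True,
    )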
@@ -169,12 +169,13 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """

-    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
         """
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
             of 1st half and 2nd half (GPT-NeoX style).
         """
         super().__init__()
+        self.base = float(base)
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                 dtype=torch.float32) / dim))
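Storing self.base means the configured value survives to be read back later (the MHA forward below uses it). A quick standalone check of the inv_freq buffer the constructor builds, showing how the base reshapes the frequencies:

import torch

dim = 64
for base in (10000.0, 24000.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # A larger base shrinks the high-index frequencies, stretching the
    # positional wavelengths.
    print(base, inv_freq[-1].item())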
@@ -75,6 +75,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
     out_proj_bias = getattr(config, 'out_proj_bias', True)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
     rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
     rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
     use_flash_attn = getattr(config, 'use_flash_attn', False)

@@ -91,7 +92,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                         qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                         dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
-                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                        rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
+                        rotary_emb_scale_base=rotary_emb_scale_base,
                         rotary_emb_interleaved=rotary_emb_interleaved,
                         use_flash_attn=use_flash_attn,
                         **serial_kwargs, **parallel_kwargs, **factory_kwargs)
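End to end, the new setting rides along with the existing rotary config knobs. A sketch of how a hypothetical config object would drive it, using the same getattr defaults as the diff:

from types import SimpleNamespace

config = SimpleNamespace(rotary_emb_fraction=0.5, rotary_emb_base=24000.0)
head_dim = 64

rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)  # new knob
print(rotary_emb_dim, rotary_emb_base)  # 32 24000.0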
@@ -350,9 +350,9 @@ class MHA(nn.Module):
     def __init__(self, embed_dim, num_heads, cross_attn=False,
                  qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 fused_bias_fc=False, use_flash_attn=False, return_residual=False,
-                 checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
+                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
         """
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us

@@ -377,7 +377,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)

         if fused_bias_fc and FusedDense is None:
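Constructing an MHA with a non-default base is then a one-argument change. A usage sketch; the import path is an assumption about where the class lives, and the other values are illustrative:

from flash_attn.modules.mha import MHA

mha = MHA(embed_dim=1024, num_heads=16, causal=True,
          rotary_emb_dim=32,          # rotary over half of each 64-dim head
          rotary_emb_base=24000.0)    # forwarded to RotaryEmbedding(base=...)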
@@ -511,11 +512,12 @@ class MHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
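Reading self.rotary_emb.base here keeps the prompt-processing path (the Python RotaryEmbedding) and the single-query decoding path (the FT kernel) on the same frequencies; when rotary is disabled the 0 is never used, since the kernel only applies rotary when rotary_embedding_dim > 0. A small sketch of the invariant being preserved:

import math, torch

base, dim, pos = 24000.0, 64, 100
# Python-side RotaryEmbedding frequency for channel pair zid = 2:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
angle_py = (pos * inv_freq[1]).item()
# Kernel-side rotary_embedding_coefficient for the same pair:
angle_cu = pos / (base ** (2 / dim))
assert math.isclose(angle_py, angle_cu, rel_tol=1e-5)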
@@ -555,8 +557,8 @@ class ParallelMHA(nn.Module):

     def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 use_flash_attn=False, checkpointing=False,
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()

@@ -573,7 +575,8 @@ class ParallelMHA(nn.Module):

         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)

         if ColumnParallelLinear is None or RowParallelLinear is None:

@@ -631,11 +634,12 @@ class ParallelMHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim, inference_params.sequence_len_offset,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )

@@ -36,7 +36,8 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     config = GPT2Config.from_pretrained(model_name)
     if rotary:
         config.n_positions = 0
-        config.rotary_emb_dim = 64
+        config.rotary_emb_fraction = 0.5
+        config.rotary_emb_base = 24000
     config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True
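The test now pins a non-default base (24000), so a regression that silently drops the new argument and falls back to the old hard-coded 10000 shifts the rotary angles and should surface as a decoding mismatch. A quick illustration of the size of that shift:

def angle(pos, zid, dim, base):
    # Same formula as rotary_embedding_coefficient above.
    return pos / (base ** (zid / dim))

# At position 100, channel pair zid=2 of a 64-dim rotary subspace:
print(angle(100, 2, 64, 10000.0))  # old hard-coded base
print(angle(100, 2, 64, 24000.0))  # base the test now configures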