From 48bc6eacd61b4b57bbd250057655d52f7068ba2f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 30 May 2023 13:38:34 -0700
Subject: [PATCH] [Gen] Add rotary base as an argument to FT attention kernel

---
 .../decoder_masked_multihead_attention.h      |   1 +
 ...er_masked_multihead_attention_template.hpp |   8 +-
 ...decoder_masked_multihead_attention_utils.h | 108 +++++++++---------
 csrc/ft_attention/ft_attention.cpp            |   7 +-
 flash_attn/layers/rotary.py                   |   3 +-
 flash_attn/models/gpt.py                      |   4 +-
 flash_attn/modules/mha.py                     |  22 ++--
 tests/models/test_gpt_generation.py           |   3 +-
 8 files changed, 84 insertions(+), 72 deletions(-)

diff --git a/csrc/ft_attention/decoder_masked_multihead_attention.h b/csrc/ft_attention/decoder_masked_multihead_attention.h
index c25c87f..590b02c 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention.h
+++ b/csrc/ft_attention/decoder_masked_multihead_attention.h
@@ -84,6 +84,7 @@ struct Multihead_attention_params_base {
     // The per-head latent space reserved for rotary embeddings.
     int  rotary_embedding_dim = 0;
     bool neox_rotary_style    = false;
+    float rotary_base         = 0.0f;
     // The maximum length of input sentences.
     int max_input_length = 0;
     // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
index a58d601..8da5929 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp
@@ -1061,10 +1061,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
         else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@@ -1099,13 +1099,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params
diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
--- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
-inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
 
-inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
 
-inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
 
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
 }
 
-inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
@@ -1352,166 +1352,166 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
     Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
     k_.x = rotary_embedding_transform(k_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
     k_.y = rotary_embedding_transform(k_.y, coef1);
 }
 
-inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
 
-inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
 
-inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }
 
-inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }
 
-inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }
 
-inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
 
 #ifdef ENABLE_BF16
-inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
 
 inline __device__ void
-apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
 
-inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }
 
-inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }
 
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }
 
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
diff --git a/csrc/ft_attention/ft_attention.cpp b/csrc/ft_attention/ft_attention.cpp
index 03b7199..de19a4d 100644
--- a/csrc/ft_attention/ft_attention.cpp
+++ b/csrc/ft_attention/ft_attention.cpp
@@ -54,6 +54,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
+                const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
                 T *q_ptr,
@@ -82,6 +83,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.num_heads = nheads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
+    params.rotary_base = rotary_base;
     params.neox_rotary_style = neox_rotary_style;
     params.timestep = timestep;
     params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
@@ -107,6 +109,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      c10::optional<torch::Tensor> length_per_sample_,
                                      const int timestep,
                                      const int rotary_embedding_dim = 0,
+                                     const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);
@@ -144,7 +147,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, neox_rotary_style, q.stride(0),
+                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
@@ -163,5 +166,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
           py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
-          py::arg("neox_rotary_style")=true);
+          py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index 437b3a7..44ceab5 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -169,12 +169,13 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """
 
-    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
         """
             interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
                 of 1st half and 2nd half (GPT-NeoX style).
         """
         super().__init__()
+        self.base = float(base)
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32)
                                    / dim))
diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 111641a..88df046 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -75,6 +75,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
     out_proj_bias = getattr(config, 'out_proj_bias', True)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
     rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
     rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
@@ -91,7 +92,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                            qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                            dropout=config.attn_pdrop, softmax_scale=softmax_scale, causal=True,
                            layer_idx=layer_idx,
-                           rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                           rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
+                           rotary_emb_scale_base=rotary_emb_scale_base,
                            rotary_emb_interleaved=rotary_emb_interleaved,
                            use_flash_attn=use_flash_attn,
                            **serial_kwargs, **parallel_kwargs, **factory_kwargs)
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 098d982..9d68668 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -350,9 +350,9 @@ class MHA(nn.Module):
 
     def __init__(self, embed_dim, num_heads, cross_attn=False, qkv_proj_bias=True,
                  out_proj_bias=True, dropout=0.0, softmax_scale=None, causal=False,
                  layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 fused_bias_fc=False, use_flash_attn=False, return_residual=False,
-                 checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
+                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
         """
             return_residual: whether to return the input x along with the output. This is for
                 performance reason: for post-norm architecture, returning the input allows us
@@ -377,7 +377,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)
 
         if fused_bias_fc and FusedDense is None:
@@ -511,11 +512,12 @@ class MHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
@@ -555,8 +557,8 @@ class ParallelMHA(nn.Module):
 
     def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 use_flash_attn=False, checkpointing=False,
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -573,7 +575,8 @@ class ParallelMHA(nn.Module):
 
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)
 
         if ColumnParallelLinear is None or RowParallelLinear is None:
@@ -631,11 +634,12 @@ class ParallelMHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim, inference_params.sequence_len_offset,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index 652aca0..38f5afa 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -36,7 +36,8 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     config = GPT2Config.from_pretrained(model_name)
     if rotary:
         config.n_positions = 0
-        config.rotary_emb_dim = 64
+        config.rotary_emb_fraction = 0.5
+        config.rotary_emb_base = 24000
     config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True
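
Usage note (not part of the patch): a minimal sketch of how the new rotary base is expected to flow after this change. create_mixer_cls reads config.rotary_emb_base, MHA passes it to RotaryEmbedding, and during fused-kernel decoding MHA forwards self.rotary_emb.base to ft_attention.single_query_attention. The model name, sizes, and generate() arguments below are illustrative assumptions patterned on tests/models/test_gpt_generation.py, not something this commit adds.

    import torch
    from transformers import GPT2Config
    from flash_attn.models.gpt import GPTLMHeadModel

    # Configure rotary embeddings with a non-default base (mirroring the updated test).
    config = GPT2Config.from_pretrained("gpt2")
    config.n_positions = 0               # rotary replaces learned position embeddings
    config.rotary_emb_fraction = 0.5     # rotary_emb_dim = 0.5 * head_dim
    config.rotary_emb_base = 24000.0     # read via getattr(config, 'rotary_emb_base', 10000.0)
    config.residual_in_fp32 = True
    config.use_flash_attn = True
    config.fused_bias_fc = True

    model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long, device="cuda")
    # fused_ft_kernel=True routes decoding through ft_attention.single_query_attention,
    # which now receives the rotary base alongside rotary_emb_dim.
    out = model.generate(input_ids=input_ids, max_length=32, fused_ft_kernel=True)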