[Gen] Add rotary base as an argument to FT attention kernel

Author: Tri Dao
Date: 2023-05-30 13:38:34 -07:00
parent 7c766b1bbc
commit 48bc6eacd6
8 changed files with 84 additions and 72 deletions


@@ -84,6 +84,7 @@ struct Multihead_attention_params_base {
     // The per-head latent space reserved for rotary embeddings.
     int rotary_embedding_dim = 0;
     bool neox_rotary_style = false;
+    float rotary_base = 0.0f;
     // The maximum length of input sentences.
     int max_input_length = 0;
     // The current timestep. TODO(bhsueh) Check whether we only need this param in cross attention.
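
For reference, rotary_base sets the geometric progression of per-pair rotation frequencies used by rotary embeddings: pair i of a rotary dimension d rotates by angle t * base**(-2*i/d) at timestep t, so a larger base slows the high-index pairs the most. A minimal PyTorch sketch of the frequency schedule (illustrative only; it matches the inv_freq formula in rotary.py further down):

import torch

def rotary_inv_freq(rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    # One frequency per pair of channels; pair i rotates by t * inv_freq[i] at step t.
    return 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))

print(rotary_inv_freq(8))                # tensor([1.0000, 0.1000, 0.0100, 0.0010])
print(rotary_inv_freq(8, base=24000.0))  # lower frequencies, hence longer wavelengths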


@@ -1061,10 +1061,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
     if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
         if (handle_kv) {
-            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
         else {
-            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
+            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
         }
     }
     else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@@ -1099,13 +1099,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
             mmha::apply_rotary_embedding(
-                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
+                q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
             mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
         }
         else {
             mmha::apply_rotary_embedding(
-                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
+                q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
         }
         mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);


@@ -1272,9 +1272,9 @@ inline __device__ void zero(T& dst)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step)
+inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
 {
-    const float inv_freq = t_step / pow(10000.0f, zid / (float)rot_embed_dim);
+    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
     return {cos(inv_freq), sin(inv_freq)};
 }
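
The only change here is that the hard-coded 10000.0f becomes the base argument; the function still returns (cos θ, sin θ) for θ = t_step / base**(zid / rot_embed_dim). (The CUDA variable is named inv_freq, but after the division by the power of base it is really the rotation angle.) A direct Python transcription of the math, assuming nothing beyond the lines above:

import math

def rotary_coefficient(zid: int, rot_embed_dim: int, t_step: float, base: float = 10000.0):
    # theta is the rotation angle for channel pair zid at timestep t_step.
    theta = t_step / (base ** (zid / rot_embed_dim))
    return math.cos(theta), math.sin(theta)
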
@@ -1302,49 +1302,49 @@ inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162
 }
 #endif
-inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     return;
 }
-inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     return;
 }
-inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
-inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
-inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
@@ -1352,166 +1352,166 @@ inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int
     }
     Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
     Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q_.x = rotary_embedding_transform(q_.x, coef0);
     k_.x = rotary_embedding_transform(k_.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q_.y = rotary_embedding_transform(q_.y, coef1);
     k_.y = rotary_embedding_transform(k_.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
-inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
-inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }
-inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
 #ifdef ENABLE_BF16
-inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
 }
 inline __device__ void
-apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step)
+apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (2 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step);
+    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
     q = rotary_embedding_transform(q, coef);
     k = rotary_embedding_transform(k, coef);
 }
-inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (4 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
 }
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
 }
-inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step)
+inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f)
 {
     if (8 * tid >= rot_embed_dim) {
         return;
     }
-    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step);
+    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
     q.x = rotary_embedding_transform(q.x, coef0);
     k.x = rotary_embedding_transform(k.x, coef0);
-    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step);
+    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
     q.y = rotary_embedding_transform(q.y, coef1);
     k.y = rotary_embedding_transform(k.y, coef1);
-    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step);
+    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
     q.z = rotary_embedding_transform(q.z, coef2);
     k.z = rotary_embedding_transform(k.z, coef2);
-    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step);
+    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
     q.w = rotary_embedding_transform(q.w, coef3);
     k.w = rotary_embedding_transform(k.w, coef3);
 }
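
Every overload above ultimately applies the same 2x2 rotation to one channel pair; the vector widths (float2, float4, uint2, uint4, bf16_4_t, bf16_8_t) only change how many pairs a thread handles per call. A scalar Python sketch of what the float2 variant of rotary_embedding_transform (not shown in this hunk) computes:

def rotary_transform(pair: tuple[float, float], coef: tuple[float, float]) -> tuple[float, float]:
    # coef = (cos(theta), sin(theta)); this is a plain 2D rotation of the pair.
    x, y = pair
    cos_t, sin_t = coef
    return (x * cos_t - y * sin_t, x * sin_t + y * cos_t)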


@@ -54,6 +54,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
+                const float rotary_base,
                 const bool neox_rotary_style,
                 const int qkv_batch_stride,
                 T *q_ptr,
@@ -82,6 +83,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.num_heads = nheads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
+    params.rotary_base = rotary_base;
     params.neox_rotary_style = neox_rotary_style;
     params.timestep = timestep;
     params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
@@ -107,6 +109,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      c10::optional<const torch::Tensor> length_per_sample_,
                                      const int timestep,
                                      const int rotary_embedding_dim = 0,
+                                     const float rotary_base = 10000.0f,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);
@@ -144,7 +147,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, neox_rotary_style, q.stride(0),
+                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),
@@ -163,5 +166,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("single_query_attention", &single_query_attention, "Attention with a single query",
           py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
           py::arg("length_per_sample_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
-          py::arg("neox_rotary_style")=true);
+          py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
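
After this change, the extension's Python-side signature reads as below. This is a stub transcribed from the pybind11 registration above, not an implementation; the defaults match the py::arg defaults:

def single_query_attention(q, k, v, k_cache, v_cache, length_per_sample_, timestep,
                           rotary_embedding_dim=0, rotary_base=10000.0,
                           neox_rotary_style=True):
    ...  # implemented in CUDA; see set_params above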


@@ -169,12 +169,13 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """
-    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
+    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, device=None):
         """
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
             of 1st half and 2nd half (GPT-NeoX style).
         """
         super().__init__()
+        self.base = float(base)
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                 dtype=torch.float32) / dim))
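
Storing base on the module lets the inference path in mha.py below read it back and hand it to the FT kernel. Numerically, raising the base stretches every rotary wavelength, most strongly for the slowest pairs; a small illustrative check:

import math
import torch

def rotary_wavelengths(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Wavelength in timesteps of each rotary pair: 2*pi / inv_freq.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    return 2 * math.pi / inv_freq

print(rotary_wavelengths(64, base=10000.0)[-1])  # ~ 2*pi * 10000**(62/64)
print(rotary_wavelengths(64, base=24000.0)[-1])  # noticeably longer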


@@ -75,6 +75,7 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
     qkv_proj_bias = getattr(config, 'qkv_proj_bias', True)
     out_proj_bias = getattr(config, 'out_proj_bias', True)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0)
     rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None)
     rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
@@ -91,7 +92,8 @@ def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dt
                              qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias,
                              dropout=config.attn_pdrop,
                              softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx,
-                             rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
+                             rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base,
+                             rotary_emb_scale_base=rotary_emb_scale_base,
                              rotary_emb_interleaved=rotary_emb_interleaved,
                              use_flash_attn=use_flash_attn,
                              **serial_kwargs, **parallel_kwargs, **factory_kwargs)
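
Since create_mixer_cls reads these settings with getattr, a config only needs to define rotary_emb_base when it wants a non-default value. A worked example of the rotary_emb_dim derivation above, with values assumed for GPT-2 small (n_embd=768, n_head=12):

head_dim = 768 // 12                  # 64
rotary_emb_dim = int(0.5 * head_dim)  # rotary_emb_fraction = 0.5 -> 32
rotary_emb_base = 24000.0             # would come from config.rotary_emb_base
print(head_dim, rotary_emb_dim, rotary_emb_base)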


@@ -350,9 +350,9 @@ class MHA(nn.Module):
     def __init__(self, embed_dim, num_heads, cross_attn=False,
                  qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 fused_bias_fc=False, use_flash_attn=False, return_residual=False,
-                 checkpointing=False, device=None, dtype=None) -> None:
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
+                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
         """
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
@@ -377,7 +377,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)
         if fused_bias_fc and FusedDense is None:
@@ -511,11 +512,12 @@ class MHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )
@@ -555,8 +557,8 @@ class ParallelMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, process_group, qkv_proj_bias=True, out_proj_bias=True,
                  dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
-                 rotary_emb_dim=0, rotary_emb_scale_base=None, rotary_emb_interleaved=False,
-                 use_flash_attn=False, checkpointing=False,
+                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
+                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
                  sequence_parallel=True, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
@@ -573,7 +575,8 @@ class ParallelMHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
+                                              scale_base=rotary_emb_scale_base,
                                               interleaved=rotary_emb_interleaved, device=device)
         if ColumnParallelLinear is None or RowParallelLinear is None:
@@ -631,11 +634,12 @@ class ParallelMHA(nn.Module):
             k_cache, v_cache = inference_params.key_value_memory_dict[self.layer_idx]
             lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
                                   if inference_params.lengths_per_sample is not None else None)
+            rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
             context = ft_attention.single_query_attention(
                 *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                 k_cache[batch_start:batch_end], v_cache[batch_start:batch_end],
                 lengths_per_sample, inference_params.sequence_len_offset,
-                self.rotary_emb_dim, inference_params.sequence_len_offset,
+                self.rotary_emb_dim, rotary_emb_base,
                 # neox_rotary_style
                 (not self.rotary_emb.interleaved) if self.rotary_emb_dim > 0 else True
             )


@@ -36,7 +36,8 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     config = GPT2Config.from_pretrained(model_name)
     if rotary:
         config.n_positions = 0
-        config.rotary_emb_dim = 64
+        config.rotary_emb_fraction = 0.5
+        config.rotary_emb_base = 24000
     config.residual_in_fp32 = True
     if optimized:
         config.use_flash_attn = True
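
The test now exercises a non-default base end to end: with GPT-2 small's head_dim of 64, rotary_emb_fraction = 0.5 yields rotary_emb_dim = 32, and rotary_emb_base = 24000 flows through create_mixer_cls, RotaryEmbedding, and finally ft_attention.single_query_attention. A standalone sketch of the same configuration (illustrative; requires the transformers package and downloads the config):

from transformers import GPT2Config

config = GPT2Config.from_pretrained('gpt2')  # any GPT-2 size would do here
config.n_positions = 0             # drop learned absolute position embeddings
config.rotary_emb_fraction = 0.5   # rotary_emb_dim = int(0.5 * head_dim) = 32
config.rotary_emb_base = 24000     # picked up via getattr in create_mixer_cls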