diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp index 480b2be..35c1a5b 100644 --- a/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp +++ b/csrc/ft_attention/decoder_masked_multihead_attention_template.hpp @@ -1082,10 +1082,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params 0 && !params.neox_rotary_style) { if (handle_kv) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len); } else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len); } } else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { @@ -1120,13 +1120,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params