[Gen, FT] Use tlength instead of params.timestep for rotary

This commit is contained in:
Tri Dao 2023-01-03 17:46:55 -08:00
parent a01d1213d7
commit f266fc7262

View File

@ -1082,10 +1082,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
if (handle_kv) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
}
else {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
}
}
else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
@ -1120,13 +1120,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len);
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
}
else {
mmha::apply_rotary_embedding(
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep);
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
}
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
}