[Gen, FT] Use tlength instead of params.timestep for rotary
This commit is contained in:
parent
a01d1213d7
commit
f266fc7262
@ -1082,10 +1082,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
||||
const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
|
||||
if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
|
||||
if (handle_kv) {
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len);
|
||||
}
|
||||
else {
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len);
|
||||
}
|
||||
}
|
||||
else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
|
||||
@ -1120,13 +1120,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
||||
mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
|
||||
mmha::apply_rotary_embedding(
|
||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len);
|
||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len);
|
||||
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
}
|
||||
else {
|
||||
mmha::apply_rotary_embedding(
|
||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep);
|
||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength);
|
||||
}
|
||||
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user