[FT] rotary_cos/sin should have batch_size dimension

This commit is contained in:
Tri Dao 2023-07-06 15:33:33 -07:00
parent d2f4324f4c
commit 2800efc71f
2 changed files with 15 additions and 7 deletions

View File

@ -1065,14 +1065,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
if (params.rotary_cos == nullptr) {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
} else {
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len,
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
}
}
else {
if (params.rotary_cos == nullptr) {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
} else {
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len,
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
}
}
}
@ -1112,7 +1116,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
} else {
mmha::apply_rotary_embedding(
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len,
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
}
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
@ -1123,7 +1129,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
} else {
mmha::apply_rotary_embedding(
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_cos, params.rotary_sin);
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength,
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
}
}
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);

View File

@ -160,15 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
// Validate the optional rotary cos/sin tensors and derive rotary_embedding_dim
// from rotary_cos. Runs only when the caller supplied rotary_cos_.
// NOTE(review): several statements below appear twice, an older form directly
// followed by a newer one. This looks like both sides of a diff hunk with the
// +/- markers stripped -- confirm against the real file before editing.
if (rotary_cos_.has_value()) {
auto rotary_cos = rotary_cos_.value();
CHECK_DEVICE(rotary_cos);
// Older form: dim taken from the first dimension of rotary_cos, with a
// one-extent shape check.
rotary_embedding_dim = rotary_cos.size(0) * 2;
CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
// Newer form: dim is twice the extent of the LAST dimension, and the shape
// check is given two extents, (batch_size, rotary_embedding_dim / 2).
rotary_embedding_dim = rotary_cos.size(-1) * 2;
CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
CHECK_CONTIGUOUS(rotary_cos);
TORCH_CHECK(rotary_cos.scalar_type() == input_type);
// rotary_sin is mandatory whenever rotary_cos is provided.
TORCH_CHECK(rotary_sin_.has_value());
auto rotary_sin = rotary_sin_.value();
CHECK_DEVICE(rotary_sin);
// NOTE(review): both shape checks below name rotary_cos again, although they
// sit between the device and contiguity checks of rotary_sin. No shape check
// on rotary_sin is visible in this span -- presumably rotary_sin was intended
// here; confirm and fix.
CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
CHECK_CONTIGUOUS(rotary_sin);
TORCH_CHECK(rotary_sin.scalar_type() == input_type);
}