[FT] rotary_cos/sin should have batch_size dimension
This commit is contained in:
parent
d2f4324f4c
commit
2800efc71f
@ -1065,14 +1065,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
||||
if (params.rotary_cos == nullptr) {
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
||||
} else {
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len,
|
||||
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (params.rotary_cos == nullptr) {
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
||||
} else {
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len,
|
||||
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1112,7 +1116,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
||||
} else {
|
||||
mmha::apply_rotary_embedding(
|
||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
|
||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len,
|
||||
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||
}
|
||||
|
||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||
@ -1123,7 +1129,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
|
||||
} else {
|
||||
mmha::apply_rotary_embedding(
|
||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_cos, params.rotary_sin);
|
||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength,
|
||||
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||
}
|
||||
}
|
||||
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
||||
|
||||
@ -160,15 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
|
||||
// Validate the optional precomputed rotary tables.
// If rotary_cos is supplied, rotary_sin must be supplied as well, and both go
// through the same device / shape / contiguity / dtype checks.
if (rotary_cos_.has_value()) {
    auto rotary_cos = rotary_cos_.value();
    CHECK_DEVICE(rotary_cos);
    // The last dimension of rotary_cos is half of the rotary embedding dim.
    rotary_embedding_dim = rotary_cos.size(-1) * 2;
    CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
    CHECK_CONTIGUOUS(rotary_cos);
    TORCH_CHECK(rotary_cos.scalar_type() == input_type);

    // rotary_sin is mandatory whenever rotary_cos is given.
    TORCH_CHECK(rotary_sin_.has_value());
    auto rotary_sin = rotary_sin_.value();
    CHECK_DEVICE(rotary_sin);
    // Check rotary_sin itself. The shape check here used to be applied to
    // rotary_cos a second time, which left rotary_sin's shape unverified.
    // NOTE(review): the kernel appears to offset both tables per batch entry,
    // so a mis-shaped rotary_sin could be read out of bounds - confirm.
    CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2);
    CHECK_CONTIGUOUS(rotary_sin);
    TORCH_CHECK(rotary_sin.scalar_type() == input_type);
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user