[FT] rotary_cos/sin should have batch_size dimension
This commit is contained in:
parent
d2f4324f4c
commit
2800efc71f
@ -1065,14 +1065,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
|||||||
if (params.rotary_cos == nullptr) {
|
if (params.rotary_cos == nullptr) {
|
||||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
||||||
} else {
|
} else {
|
||||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
|
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len,
|
||||||
|
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||||
|
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if (params.rotary_cos == nullptr) {
|
if (params.rotary_cos == nullptr) {
|
||||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
||||||
} else {
|
} else {
|
||||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
|
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len,
|
||||||
|
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||||
|
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1112,7 +1116,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
|||||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
|
||||||
} else {
|
} else {
|
||||||
mmha::apply_rotary_embedding(
|
mmha::apply_rotary_embedding(
|
||||||
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_cos, params.rotary_sin);
|
q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len,
|
||||||
|
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||||
|
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
|
||||||
@ -1123,7 +1129,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
|
|||||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
|
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
|
||||||
} else {
|
} else {
|
||||||
mmha::apply_rotary_embedding(
|
mmha::apply_rotary_embedding(
|
||||||
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_cos, params.rotary_sin);
|
q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength,
|
||||||
|
params.rotary_cos + bi * params.rotary_embedding_dim / 2,
|
||||||
|
params.rotary_sin + bi * params.rotary_embedding_dim / 2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
|
||||||
|
|||||||
@ -160,15 +160,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
|
|||||||
if (rotary_cos_.has_value()) {
|
if (rotary_cos_.has_value()) {
|
||||||
auto rotary_cos = rotary_cos_.value();
|
auto rotary_cos = rotary_cos_.value();
|
||||||
CHECK_DEVICE(rotary_cos);
|
CHECK_DEVICE(rotary_cos);
|
||||||
rotary_embedding_dim = rotary_cos.size(0) * 2;
|
rotary_embedding_dim = rotary_cos.size(-1) * 2;
|
||||||
CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
|
CHECK_SHAPE(rotary_cos, batch_size, rotary_embedding_dim / 2);
|
||||||
CHECK_CONTIGUOUS(rotary_cos);
|
CHECK_CONTIGUOUS(rotary_cos);
|
||||||
TORCH_CHECK(rotary_cos.scalar_type() == input_type);
|
TORCH_CHECK(rotary_cos.scalar_type() == input_type);
|
||||||
|
|
||||||
TORCH_CHECK(rotary_sin_.has_value());
|
TORCH_CHECK(rotary_sin_.has_value());
|
||||||
auto rotary_sin = rotary_sin_.value();
|
auto rotary_sin = rotary_sin_.value();
|
||||||
CHECK_DEVICE(rotary_sin);
|
CHECK_DEVICE(rotary_sin);
|
||||||
CHECK_SHAPE(rotary_cos, rotary_embedding_dim / 2);
|
CHECK_SHAPE(rotary_sin, batch_size, rotary_embedding_dim / 2);  // validate rotary_sin itself; rotary_cos was already shape-checked above
|
||||||
CHECK_CONTIGUOUS(rotary_sin);
|
CHECK_CONTIGUOUS(rotary_sin);
|
||||||
TORCH_CHECK(rotary_sin.scalar_type() == input_type);
|
TORCH_CHECK(rotary_sin.scalar_type() == input_type);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user