[Misc] Sort the list of embedding models (#10037)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
c4cacbaa7f
commit
82bfc38d07
@ -94,33 +94,23 @@ _TEXT_GENERATION_MODELS = {
|
||||
_EMBEDDING_MODELS = {
|
||||
# [Text-only]
|
||||
"BertModel": ("bert", "BertEmbeddingModel"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
|
||||
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
|
||||
**{
|
||||
# Multiple models share the same architecture, so we include them all
|
||||
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
|
||||
if arch == "LlamaForCausalLM"
|
||||
},
|
||||
"MistralModel": ("llama", "LlamaEmbeddingModel"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForSequenceClassification": (
|
||||
"qwen2_cls", "Qwen2ForSequenceClassification"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
|
||||
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
|
||||
}
|
||||
|
||||
def add_embedding_models(base_models, embedding_models):
|
||||
with_pooler_method_models = {}
|
||||
embedding_models_name = embedding_models.keys()
|
||||
for name, (path, arch) in base_models.items():
|
||||
if arch in embedding_models_name:
|
||||
with_pooler_method_models[name] = (path, arch)
|
||||
return with_pooler_method_models
|
||||
|
||||
_EMBEDDING_MODELS = {
|
||||
**add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS),
|
||||
**_EMBEDDING_MODELS,
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
# [Decoder-only]
|
||||
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user