torch_ext/fi/load_model.py
# coding=utf-8
import os

import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models import gemma, gemma2, llama, qwen2

# Map a config "model_type" string to that architecture's decoder-layer class
# (used to locate the repeated transformer blocks inside the loaded model).
decode_layers = {
    "gemma": gemma.modeling_gemma.GemmaDecoderLayer,
    "gemma2": gemma2.modeling_gemma2.Gemma2DecoderLayer,
    "llama": llama.modeling_llama.LlamaDecoderLayer,  # added for parity with MODELS below
    "qwen2": qwen2.modeling_qwen2.Qwen2DecoderLayer,
}

# Map a config "model_type" string to its causal-LM model class.
MODELS = {
    "gemma": gemma.GemmaForCausalLM,
    "gemma2": gemma2.Gemma2ForCausalLM,
    "llama": llama.LlamaForCausalLM,
    "qwen2": qwen2.Qwen2ForCausalLM,
}

class ModelLoader:
    def __init__(self, model_path: str, pipeline_num: int = 1):
        self.config_path = os.path.join(model_path, "config.json")
        self.model_config = AutoConfig.from_pretrained(self.config_path)

        # PretrainedConfig is not a dict, so read attributes with getattr
        # rather than .get()/[] subscripting.
        hidden_layers = getattr(self.model_config, "num_hidden_layers", -1)
        if hidden_layers == -1:
            raise ValueError("config does not have a num_hidden_layers parameter")
        self.hidden_layers = hidden_layers
        self.pipeline_num = pipeline_num

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model_type = self.model_config.model_type
        self.per_pipeline_layers = self.hidden_layers // self.pipeline_num

        # Locate the nn.ModuleList holding the stacked decoder layers;
        # the last ModuleList encountered wins, which for these
        # architectures is the decoder-layer stack.
        module_list = None
        for x in self.model.modules():
            if isinstance(x, nn.ModuleList):
                module_list = x
        if module_list is None:
            raise ValueError("model does not contain an nn.ModuleList of layers")
        self.module_list = module_list  # keep a handle for later slicing