# coding=utf-8
import os

import torch
import torch.nn as nn
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models import gemma, gemma2, llama, qwen2

# Decoder-layer classes keyed by model_type, used when splitting the
# transformer blocks across pipeline stages.
decode_layers = {
    "gemma": gemma.modeling_gemma.GemmaDecoderLayer,
    "gemma2": gemma2.modeling_gemma2.Gemma2DecoderLayer,
    "qwen2": qwen2.modeling_qwen2.Qwen2DecoderLayer,
}

# Causal-LM classes keyed by model_type.
MODELS = {
    "gemma": gemma.GemmaForCausalLM,
    "gemma2": gemma2.Gemma2ForCausalLM,
    "llama": llama.LlamaForCausalLM,
    "qwen2": qwen2.Qwen2ForCausalLM,
}


class ModelLoader:
    def __init__(self, model_path: str, pipeline_num: int = 1):
        self.config_path = os.path.join(model_path, "config.json")
        self.model_config = AutoConfig.from_pretrained(self.config_path)

        # PretrainedConfig is not a dict, so read the attribute with getattr.
        hidden_layers = getattr(self.model_config, "num_hidden_layers", -1)
        if hidden_layers == -1:
            raise ValueError("config has no num_hidden_layers parameter")
        self.hidden_layers = hidden_layers
        self.pipeline_num = pipeline_num

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model_type = self.model_config.model_type

        # Number of decoder layers assigned to each pipeline stage.
        self.per_pipeline_layers = self.hidden_layers // self.pipeline_num

        # Locate the ModuleList holding the decoder layers (the last
        # ModuleList encountered in the module tree is kept).
        module_list = None
        for x in self.model.modules():
            if isinstance(x, nn.ModuleList):
                module_list = x
        if module_list is None:
            raise ValueError("model has no ModuleList of decoder layers")
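

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the checkpoint path below is a
    # hypothetical local directory containing a Hugging Face model together
    # with its config.json; substitute a real path before running.
    loader = ModelLoader("/path/to/Qwen2-7B", pipeline_num=2)
    print(f"model_type: {loader.model_type}")
    print(f"hidden layers: {loader.hidden_layers}, "
          f"layers per pipeline stage: {loader.per_pipeline_layers}")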