# coding=utf-8
import os

import torch
import torch.nn as nn
import transformers
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models import gemma, gemma2, llama, qwen2

# Decoder-layer class for each supported model_type.
decode_layers = {
    "gemma": gemma.modeling_gemma.GemmaDecoderLayer,
    "gemma2": gemma2.modeling_gemma2.Gemma2DecoderLayer,
    "qwen2": qwen2.modeling_qwen2.Qwen2DecoderLayer,
}

# Causal-LM class for each supported model_type.
MODELS = {
    "gemma": gemma.GemmaForCausalLM,
    "gemma2": gemma2.Gemma2ForCausalLM,
    "llama": llama.LlamaForCausalLM,
    "qwen2": qwen2.Qwen2ForCausalLM,
}


class ModelLoader:
    """Load a causal LM checkpoint and compute how many decoder layers
    each pipeline stage should hold."""

    def __init__(self, model_path: str, pipeline_num: int = 1):
        self.config_path = os.path.join(model_path, "config.json")
        self.model_config = AutoConfig.from_pretrained(self.config_path)
        # PretrainedConfig exposes fields as attributes, not dict keys.
        hidden_layers = getattr(self.model_config, "num_hidden_layers", -1)
        if hidden_layers == -1:
            raise ValueError("config has no num_hidden_layers parameter")
        self.hidden_layers = hidden_layers
        self.pipeline_num = pipeline_num
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path, trust_remote_code=True
        )
        self.model_type = self.model_config.model_type
        # Number of decoder layers assigned to each pipeline stage.
        self.per_pipeline_layers = self.hidden_layers // self.pipeline_num
        # Find the ModuleList holding the decoder layers (keeps the last one found).
        module_list = None
        for x in self.model.modules():
            if isinstance(x, nn.ModuleList):
                module_list = x
        if module_list is None:
            raise ValueError("model does not contain a ModuleList of decoder layers")
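

# A minimal usage sketch, not part of the original file: "path/to/model" and
# pipeline_num=2 are placeholder values assumed for illustration.
if __name__ == "__main__":
    loader = ModelLoader("path/to/model", pipeline_num=2)
    print(loader.model_type, loader.hidden_layers, loader.per_pipeline_layers)
    # Look up the decoder-layer and model classes registered for this model_type.
    decoder_cls = decode_layers.get(loader.model_type)
    model_cls = MODELS.get(loader.model_type)
    print(decoder_cls, model_cls)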