torch_ext/fi/load_model.py
2024-11-16 19:26:54 +08:00

15 lines
400 B
Python

# coding=utf-8
import torch
import transformers
import torch.nn as nn
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers import AutoModel, AutoConfig
class ModelLoader:
    """Load a Hugging Face ``transformers`` model via ``AutoModel``.

    The loaded model is exposed as the ``model`` attribute.
    """

    def __init__(
        self,
        model_path: str,
        *,
        trust_remote_code: bool = True,
        **from_pretrained_kwargs: object,
    ) -> None:
        """Load the model at *model_path* with ``AutoModel.from_pretrained``.

        Args:
            model_path: Local directory or model identifier, passed straight
                through to ``AutoModel.from_pretrained``.
            trust_remote_code: Forwarded to ``from_pretrained``. Defaults to
                ``True`` to keep the original behaviour; pass ``False`` for
                checkpoints whose source you do not control.
            **from_pretrained_kwargs: Any further keyword arguments, forwarded
                unchanged to ``from_pretrained`` (e.g. ``torch_dtype``).
        """
        # SECURITY(review): trust_remote_code=True lets the checkpoint execute
        # its own Python modeling code during loading; only enable it for
        # trusted sources.
        # NOTE(review): AutoModel resolves to the bare backbone class (no LM
        # head). The file imports Qwen2ForCausalLM / LlamaForCausalLM, so if
        # logits or generation are needed, AutoModelForCausalLM may be the
        # intended loader -- confirm with callers.
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code,
            **from_pretrained_kwargs,
        )