15 lines
400 B
Python
15 lines
400 B
Python
# coding=utf-8
|
|
import torch
|
|
import transformers
|
|
import torch.nn as nn
|
|
|
|
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
|
|
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
|
|
from transformers import AutoModel, AutoConfig
|
|
|
|
|
|
class ModelLoader:
    """Load a pretrained model with ``transformers.AutoModel``.

    The loaded model is exposed as ``self.model``.
    """

    def __init__(
        self,
        model_path: str,
        *,
        trust_remote_code: bool = True,
        **from_pretrained_kwargs,
    ):
        """Load the model found at *model_path* into ``self.model``.

        Args:
            model_path: Checkpoint location, passed as the first argument to
                ``AutoModel.from_pretrained``.
            trust_remote_code: Forwarded to ``from_pretrained``. When True,
                custom modeling code shipped with the checkpoint is executed
                locally, so only enable it for sources you trust. Defaults to
                True to preserve the original hard-coded behavior.
            **from_pretrained_kwargs: Any further keyword arguments, forwarded
                unchanged to ``AutoModel.from_pretrained``.
        """
        # NOTE(review): AutoModel resolves to the bare architecture (no task
        # head); this module also imports Qwen2ForCausalLM / LlamaForCausalLM,
        # so confirm a bare model is what callers expect.
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code,
            **from_pretrained_kwargs,
        )
|