2022-12-28 01:49:59 +08:00
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from transformers.utils import WEIGHTS_NAME
|
|
|
|
|
from transformers.utils.hub import cached_file
|
|
|
|
|
|
|
|
|
|
|
2023-01-16 03:34:27 +08:00
|
|
|
def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Fetch a pretrained checkpoint from the Hugging Face hub cache and return its state dict.

    Parameters
    ----------
    model_name : str
        Model identifier (or local path) understood by ``transformers``'s
        ``cached_file``.
    device : optional
        Passed through as ``map_location`` to ``torch.load``; ``None`` keeps
        torch's default behavior (tensors restored to their saved device).
    dtype : optional
        When given, every tensor in the state dict is cast to this dtype.

    Returns
    -------
    dict
        Mapping of parameter names to tensors.
    """
    # NOTE(review): torch.load unpickles the downloaded file — only safe for
    # trusted model sources.
    checkpoint_path = cached_file(model_name, WEIGHTS_NAME)
    loaded = torch.load(checkpoint_path, map_location=device)
    if dtype is None:
        return loaded
    return {name: tensor.to(dtype) for name, tensor in loaded.items()}
|