9 lines
246 B
Python
9 lines
246 B
Python
import torch
|
|
|
|
from transformers.utils import WEIGHTS_NAME
|
|
from transformers.utils.hub import cached_file
|
|
|
|
|
|
def state_dict_from_pretrained(model_name, device=None):
    """Load the ``WEIGHTS_NAME`` checkpoint file associated with ``model_name``.

    The file path is resolved with ``cached_file`` and the file is then
    deserialized with ``torch.load``.

    Args:
        model_name: Identifier passed unchanged as the first argument to
            ``cached_file`` (presumably a hub repo id or a local directory;
            confirm against callers).
        device: Passed to ``torch.load`` as ``map_location``. Defaults to
            ``None``.

    Returns:
        The object produced by ``torch.load`` for the resolved file
        (presumably a state dict; confirm against the checkpoints in use).
    """
    # NOTE(review): torch.load is pickle-based; only point this at
    # checkpoints that come from a trusted source.
    checkpoint_path = cached_file(model_name, WEIGHTS_NAME)
    loaded = torch.load(checkpoint_path, map_location=device)
    return loaded
|