flash-attention/flash_attn/utils/pretrained.py

import torch
from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file


def state_dict_from_pretrained(model_name, device=None):
    # Download (or reuse the cached copy of) the checkpoint's pytorch_model.bin
    # from the Hugging Face hub and load it, mapping tensors onto `device`.
    return torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
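

# A minimal usage sketch, assuming a checkpoint that ships a pytorch_model.bin
# weights file; "gpt2" is only an illustrative model name, not part of this module.
if __name__ == "__main__":
    state_dict = state_dict_from_pretrained("gpt2", device="cpu")
    print(f"loaded {len(state_dict)} tensors; first key: {next(iter(state_dict))}")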