flash-attention/flash_attn/utils/pretrained.py

12 lines
381 B
Python
Raw Normal View History

import torch
from transformers.utils import WEIGHTS_NAME
from transformers.utils.hub import cached_file
def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load a pretrained checkpoint's state dict, optionally casting its dtype.

    The checkpoint file named ``WEIGHTS_NAME`` is looked up for ``model_name``
    through ``transformers``' ``cached_file`` and deserialized with
    ``torch.load``.

    Args:
        model_name: Model identifier, forwarded unchanged to ``cached_file``.
        device: Forwarded to ``torch.load`` as ``map_location``. ``None``
            leaves tensor placement up to ``torch.load``.
        dtype: When not ``None``, every floating-point tensor in the state
            dict is converted to this dtype. Integer, bool and complex
            tensors, as well as non-tensor entries, are returned unchanged so
            that index-like buffers are not silently turned into
            low-precision floats.

    Returns:
        The object produced by ``torch.load`` when ``dtype`` is ``None``;
        otherwise a new ``dict`` with the same keys and converted values.
    """
    # SECURITY NOTE: torch.load unpickles the checkpoint, which can execute
    # arbitrary code. Only use this with checkpoints from a trusted source.
    state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
    if dtype is None:
        return state_dict

    def _cast(value):
        # Only floating-point tensors are meaningful to re-cast; anything else
        # (int/bool/complex tensors, stray metadata) is passed through as-is.
        if torch.is_tensor(value) and value.is_floating_point():
            return value.to(dtype)
        return value

    return {key: _cast(value) for key, value in state_dict.items()}