From 187c2a06358d421e3d350fac8ff8714013c1f1cd Mon Sep 17 00:00:00 2001
From: Yuchao Dai <3407450+icyblade@users.noreply.github.com>
Date: Fri, 22 Sep 2023 02:48:23 +0800
Subject: [PATCH] Fix E1136 (#563)

---
 flash_attn/models/gpt.py   |  3 ++-
 flash_attn/models/llama.py | 16 ++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index e822028..b2403dc 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -6,6 +6,7 @@ import re
 from collections import OrderedDict, namedtuple
 from collections.abc import Sequence
 from functools import partial
+from typing import Dict, List
 
 import torch
 import torch.nn as nn
@@ -810,7 +811,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict
 
 
-def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
     """Convert the list of sharded state_dict of a GPT model with tensor parallel to
     the state_dict of a standard GPT model.
 
diff --git a/flash_attn/models/llama.py b/flash_attn/models/llama.py
index 2841efd..3bfb51d 100644
--- a/flash_attn/models/llama.py
+++ b/flash_attn/models/llama.py
@@ -6,7 +6,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union
 
 import torch
 import torch.nn.functional as F
@@ -17,8 +17,8 @@ from einops import rearrange
 
 
 def remap_state_dict_meta_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Meta format to standard GPT format.
 
     This function modifies state_dict in place.
@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama(
 
 
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.
 
     This function modifies state_dict in place.
@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama(
 
 
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.
 
     This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
@@ -382,7 +382,7 @@ def config_from_checkpoint(
 
 def state_dicts_from_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
-) -> list[dict]:
+) -> List[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [
         torch.load(path, map_location="cpu")