flash-attention/flash_attn/models
JDKWangGuan 0d810cfb73
Fix KeyError handling for non-existing key in state_dict.pop() (#898)
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) so that a checkpoint missing these keys no longer raises a KeyError.
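
In context, the fix is a one-line change in the state-dict remapping loop: giving pop a default turns a missing key into a no-op. A minimal sketch of the pattern (the loop structure and names are illustrative stand-ins, not the exact source):

```python
# Sketch of the fix in the HF-to-flash-attn state_dict remapping
# (names here are illustrative stand-ins, not the exact source):
n_layer = 12  # e.g. GPT-2 small
state_dict = {}  # a re-saved checkpoint with no attn.bias buffers

for d in range(n_layer):
    # Before: raises KeyError because the key is absent.
    # state_dict.pop(f"h.{d}.attn.bias")

    # After: the None default makes the pop a no-op for missing keys.
    state_dict.pop(f"h.{d}.attn.bias", None)
```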


The following code reproduces the issue:
```python
from transformers import GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# >>> transformers.__version__
# '4.38.2'

model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
# ... model fine-tuning here ...
# dump the fine-tuned model
model.save_pretrained(output_model_path)

# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # fails with KeyError: 'h.0.attn.bias'
```
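The key is missing because recent transformers releases register GPT-2's causal-mask bias as a non-persistent buffer, so save_pretrained does not write h.{d}.attn.bias into the checkpoint. A quick way to confirm this on the re-saved model (a sketch assuming the default safetensors serialization; the path follows the repro above):

```python
# List any attn.bias buffers present in the re-saved checkpoint
# (illustrative check; assumes transformers wrote model.safetensors).
from safetensors.torch import load_file

sd = load_file("gpt2_model/model.safetensors")
print([k for k in sd if k.endswith("attn.bias")])  # expected: []
```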
2024-06-30 22:40:03 -07:00
| File | Last commit | Date |
|------|-------------|------|
| __init__.py | Add __init__.py files to subdirectories for installation | 2022-11-17 16:55:44 -08:00 |
| baichuan.py | Pass alibi slopes to flash_attn_with_kvcache during generation | 2023-12-24 20:31:59 -08:00 |
| bert.py | [LayerNorm] Switch from CUDA to Triton implementation | 2024-01-05 00:31:17 -08:00 |
| bigcode.py | Add BigCode converters (#532) | 2023-09-10 17:24:50 -07:00 |
| btlm.py | Implement BTLM model | 2023-12-24 20:35:12 -08:00 |
| falcon.py | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| gpt_neox.py | [Gen] Remove minor dead code | 2023-12-19 22:57:39 -08:00 |
| gpt.py | Fix KeyError handling for non-existing key in state_dict.pop() (#898) | 2024-06-30 22:40:03 -07:00 |
| gptj.py | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| llama.py | Fix E1136 (#563) | 2023-09-21 11:48:23 -07:00 |
| opt.py | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| vit.py | [LayerNorm] Switch from CUDA to Triton implementation | 2024-01-05 00:31:17 -08:00 |