flash-attention/flash_attn
JDKWangGuan 0d810cfb73
Fix KeyError handling for non-existing key in state_dict.pop() (#898)
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) so that checkpoints that do not contain these keys no longer raise a KeyError.
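For context, a minimal sketch of the dict.pop() behavior behind the change (illustrative only, not the library's actual remapping code; the state_dict below is a placeholder):
```python
# Placeholder state_dict: checkpoints saved by newer transformers releases
# omit the attention bias (causal-mask) buffers, so "h.0.attn.bias" is absent.
state_dict = {"h.0.attn.c_attn.weight": "..."}
d = 0

# Previous behavior: dict.pop(key) raises KeyError when the key is missing.
try:
    state_dict.pop(f"h.{d}.attn.bias")
except KeyError as e:
    print(f"KeyError: {e}")  # prints: KeyError: 'h.0.attn.bias'

# Fixed behavior: dict.pop(key, None) returns None for a missing key and continues.
state_dict.pop(f"h.{d}.attn.bias", None)
```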


The following code reproduces the issue:
```python
from transformers import GPT2Config, GPT2Model
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# >>> transformers.__version__
# '4.38.2'

model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
# ... fine-tune the model here ...
# dump the fine-tuned model
model.save_pretrained(output_model_path)

# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # before this fix: KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # before this fix: KeyError: 'h.0.attn.bias'
```
2024-06-30 22:40:03 -07:00
| Name | Last commit message | Last commit date |
| --- | --- | --- |
| `layers` | Fix typo in RotaryEmbedding forward output type (#666) | 2023-11-09 11:43:02 -08:00 |
| `losses` | [CrossEntropy] Change ignored_index -> ignore_index | 2024-04-26 10:50:41 -07:00 |
| `models` | Fix KeyError handling for non-existing key in state_dict.pop() (#898) | 2024-06-30 22:40:03 -07:00 |
| `modules` | fix: cast the alibi slopes to torch.float32 (#846) | 2024-03-15 00:49:40 -07:00 |
| `ops` | remove an unused import (#960) | 2024-05-23 11:12:31 -07:00 |
| `utils` | Update citation | 2024-05-26 16:09:03 -07:00 |
| `__init__.py` | Limit to MAX_JOBS=1 with CUDA 12.2 | 2024-05-26 15:35:49 -07:00 |
| `bert_padding.py` | Updated missing docstrings for args and returns in bert_padding.py (#795) | 2024-01-27 09:16:25 -08:00 |
| `flash_attn_interface.py` | Support unpadded LSE layout (#970) | 2024-06-27 02:38:13 -07:00 |
| `flash_attn_triton_og.py` | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| `flash_attn_triton.py` | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| `flash_blocksparse_attention.py` | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| `flash_blocksparse_attn_interface.py` | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| `fused_softmax.py` | Run isort and black on python files | 2023-08-18 14:22:11 -07:00 |
| `pyproject.toml` | Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 | 2023-08-25 15:05:28 -07:00 |