Commit Graph

75 Commits

Author SHA1 Message Date
Zhihao Shen
30e1ef0f79
minify torch.torch.int32 to torch.int32 (#1237) 2024-09-18 00:32:59 -07:00
Ying Zhang
cdbbe844b1 minor changes to unpad_input test util func 2024-09-16 14:24:11 -07:00
JDKWangGuan
0d810cfb73
Fix KeyError handling for non-existing key in state_dict.pop() (#898)
Update handling for KeyError in state_dict.pop() for non-existing keys.
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) to prevent KeyError exceptions.


The following code can re-produce the issue
```
from transformers import AutoTokenizer, GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# >>> transformers.__version__
# '4.38.2'

model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
'''
model fine-tuning here
'''
# dump the fine-tuned model
model.save_pretrained(output_model_path)

# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'

```
2024-06-30 22:40:03 -07:00
Tri Dao
ef0ed10622 Add window_size option to MHA and GPT 2024-01-31 02:42:23 -08:00
Tri Dao
abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 2024-01-05 00:31:17 -08:00
Tri Dao
73df3be7d5 Add test for BTLM init 2023-12-25 15:16:27 -08:00
Tri Dao
7ffba9a501 Implement BTLM model 2023-12-24 20:35:12 -08:00
Tri Dao
2e29dacf0c Implement muParam 2023-12-24 20:34:48 -08:00
Tri Dao
3f7d5786ba Pass alibi slopes to flash_attn_with_kvcache during generation 2023-12-24 20:31:59 -08:00
Tri Dao
2c7d7b7396 Implement norm head for Baichuan2 2023-12-22 16:55:40 -08:00
Tri Dao
c3b2196652 Add Alibi to MHA, test with Baichuan-13B 2023-12-21 22:49:55 -08:00
Tri Dao
0a146185d6 [Gen] Remove minor dead code 2023-12-19 22:57:39 -08:00
Yuchao Dai
187c2a0635
Fix E1136 (#563) 2023-09-21 11:48:23 -07:00
Tri Dao
0705d2718d [Llama] Fix some tests, add tests for Llama 2 and CodeLlama 2023-09-20 23:36:46 -07:00
Kevin Hu
42832575d4
Fix Llama GQA/MQA (#546)
* Fix llama MQA

* Fix permute shape

* Update llama.py
2023-09-19 22:15:59 -07:00
Tri Dao
d0032700d1 Add tests for Pythia, GPT-JT, and RedPajama models 2023-09-13 01:10:39 -07:00
Kevin Hu
07005806ff
Add BigCode converters (#532) 2023-09-10 17:24:50 -07:00
Kevin Hu
4c91621a5e
Inverse state dict for BERT (#527) 2023-09-09 01:44:21 -07:00
Tri Dao
798858f9f1 Fix test_baichuan 2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72 [Gen] Add back num_last_tokens in gpt.py 2023-09-03 20:44:40 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa (#491)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained (#479) 2023-08-24 16:31:16 -07:00
GAOXinyu
a8c35b4f57
FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)
* q

* add comment.
2023-08-20 14:57:34 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format (#465)
* fix name.

* set inv function.

* add map back function.

* handle gqa.

* add type annotation to avoid confusion.

* fix docstr.

* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reuseable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
ada4710d70 [ViT] Run black on vit.py 2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1 [ViT] Minor fix so it runs 2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d [GPT] Run black on gpt.py 2023-08-16 23:47:50 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training (#446)
* prelim.

* add hf convertion fn.

* mlp.

* change name.

* fix bug.

* inverse permute.

* change comment.

* revert style changes.

* fix.

* add doc.

* revert.

* enable load safe.

* fix safe load.

* fix import.

* fix typing-related lints.

* fix ckpt loading logic.

* make single gpu work.

* test with parallel.

* ckpt format.

* enable pretrained state dict.

* remove unused imports.

* remove unused.

* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
184b992dcb [GPT] Implement parallel LLaMa 2023-07-28 15:52:48 -10:00
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp (#251) 2023-07-26 12:14:15 -07:00
Tri Dao
d38357dd2f [GPT] Implement Falcon 2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert (#363) 2023-07-23 00:21:45 -07:00
Tri Dao
425dbcb6c6 [MHA] Implement MQA/GQA 2023-07-23 00:06:58 -07:00
Tri Dao
ec9f74ab9a [Rotary] Don't store inv_freq in state_dict 2023-07-22 23:52:42 -07:00
Tri Dao
75e334d407 [MLP] Add ParallelMLP 2023-07-22 23:45:51 -07:00
Tri Dao
48bc6eacd6 [Gen] Add rotary base as an argument to FT attention kernel 2023-05-30 13:38:34 -07:00
Federico Berto
69f5f7d0a2 [BugFix] cannot unpack non-iterable NoneType object 2023-05-07 03:07:44 +09:00
Tri Dao
a9a4b4e4f2 [LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm 2023-05-04 23:39:43 -07:00
Tri Dao
ba2fe7f378 [Gen] Move allocate_inference_cache to within the model 2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1 [GPT] Add option to only return the logit for the last token 2023-04-20 17:21:08 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f Implement GatedMlp 2023-04-18 03:37:14 -07:00
Tri Dao
6f6e9a9aaf [FusedDense] Enable sqrelu activation in FusedMLP 2023-04-13 15:29:32 -07:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00