Commit Graph

316 Commits

Author SHA1 Message Date
Tri Dao
81e01efd4b More typo fixes 2024-07-10 10:19:17 -07:00
Tri Dao
72e27c6320 Fix typo with softcapping 2024-07-10 00:33:52 -07:00
Phil Wang
f4628b43ec
missing commas and backwards return arguments (#1032)
* missing commas

* another fix
2024-07-09 10:56:29 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
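For context on what softcapping does (the commit message itself doesn't say): the usual formulation squashes attention scores through a scaled tanh before the softmax, bounding them to (-softcap, softcap). A minimal PyTorch sketch of that transform, independent of the fused CUDA kernel added in #1025:
```
import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Bound raw attention scores to (-softcap, softcap) with a scaled tanh;
    # the fused kernel presumably applies an equivalent transform before softmax.
    return softcap * torch.tanh(scores / softcap)

# Toy usage: cap scaled q @ k^T scores at +/- 30 before the softmax.
q = torch.randn(1, 8, 128, 64)
k = torch.randn(1, 8, 128, 64)
scores = softcap_scores(q @ k.transpose(-2, -1) / 64**0.5, softcap=30.0)
attn = torch.softmax(scores, dim=-1)
```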
Jianwei Dong
4e8d60069f
Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989) 2024-07-08 08:29:40 -07:00
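A hedged usage sketch of the option added in #989; that the call returns an (output, logsumexp) pair when the flag is set is an assumption based on the commit description, not a checked signature:
```
import torch
from flash_attn import flash_attn_with_kvcache

batch, seqlen_k, nheads, headdim = 2, 512, 8, 64
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(batch, seqlen_k, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# With return_softmax_lse=True the call is expected to also return the
# per-query logsumexp of the attention scores (useful e.g. for merging
# partial attention outputs).
out, lse = flash_attn_with_kvcache(q, k_cache, v_cache, return_softmax_lse=True)
```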
JDKWangGuan
0d810cfb73
Fix KeyError for non-existent keys in state_dict.pop() (#898)
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) so that missing keys no longer raise a KeyError.


The following code reproduces the issue:
```
from transformers import GPT2Config, GPT2Model
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# Reproduced with transformers 4.38.2.
model_path = 'gpt2'
output_model_path = 'gpt2_model'

# Load a Hugging Face GPT-2 model, fine-tune it, and save the checkpoint.
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
# ... model fine-tuning here ...
model.save_pretrained(output_model_path)

# Reload the fine-tuned checkpoint with the flash-attn GPT classes.
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)        # KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # KeyError: 'h.0.attn.bias'
```
2024-06-30 22:40:03 -07:00
Grigory Sizov
f816dee63c
Support unpadded LSE layout (#970)
* Support unpadded LSE layout.

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>

* Cleanup

* Fix unpadded LSE on split-kv path

* Fix formatting and comments

* Fix inline vs forceinline

---------

Co-authored-by: Xinfeng Xie <xfxie.ceca@gmail.com>
Co-authored-by: Jianyu Huang <hjyahead@gmail.com>
2024-06-27 02:38:13 -07:00
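For readers unfamiliar with the term: the softmax logsumexp (LSE) for a variable-length batch can be stored padded, shaped (batch, nheads, max_seqlen), or unpadded, shaped (nheads, total_tokens) with the sequences concatenated. A small illustrative conversion between the two; the shapes are an assumption for illustration, not taken from the PR:
```
import torch

def unpad_lse(lse_padded: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Convert LSE from (batch, nheads, max_seqlen) to (nheads, total_tokens)
    # by dropping the padded positions and concatenating along the last dim.
    batch = lse_padded.shape[0]
    chunks = []
    for b in range(batch):
        seqlen = int(cu_seqlens[b + 1] - cu_seqlens[b])
        chunks.append(lse_padded[b, :, :seqlen])
    return torch.cat(chunks, dim=-1)

cu_seqlens = torch.tensor([0, 3, 7])   # two sequences of length 3 and 4
lse_padded = torch.randn(2, 8, 4)      # batch=2, nheads=8, max_seqlen=4
print(unpad_lse(lse_padded, cu_seqlens).shape)  # torch.Size([8, 7])
```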
Tri Dao
320fb59487 Update citation 2024-05-26 16:09:03 -07:00
Tri Dao
e2e4333c95 Limit to MAX_JOBS=1 with CUDA 12.2 2024-05-26 15:35:49 -07:00
Tri Dao
ce73503578 Bump to 2.5.9 2024-05-26 14:02:11 -07:00
lancerts
22339db185
remove an unused import (#960) 2024-05-23 11:12:31 -07:00
Tri Dao
9a11f440d3 Bump to v2.5.8 2024-04-26 10:54:52 -07:00
Tri Dao
ec6d22143b [CrossEntropy] Change ignored_index -> ignore_index 2024-04-26 10:50:41 -07:00
Tri Dao
85881f547f Bump to v2.5.7 2024-04-07 20:13:05 -07:00
Ivan Komarov
f692b98d80
Fix spurious re-compilations of rotary_kernel (#911)
All integer parameters are specialized by default, so the two parameters
removed in this commit could lead to kernel re-compilation, even if
they were completely unused.
2024-04-05 13:40:41 -07:00
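The issue behind #911 is a general property of Triton's JIT: integer arguments take part in kernel specialization (e.g. the compiler keys on properties of their runtime values, such as divisibility by 16), so even an integer argument the kernel body never reads can force a fresh compile. A toy sketch of the pattern, not the actual rotary kernel:
```
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # n_elements is an integer argument, so Triton specializes on properties of
    # its runtime value; an unused integer argument would still feed the
    # specialization cache, which is why the unused rotary_kernel params were removed.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(1000, device="cuda")
out = torch.empty_like(x)
copy_kernel[(triton.cdiv(x.numel(), 256),)](x, out, x.numel(), BLOCK=256)
```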
Tri Dao
36587c01cb [LayerNorm] Update layer_norm_linear 2024-03-18 23:15:33 -07:00
Markus Krimmel
6bbc532388
fix: cast the alibi slopes to torch.float32 (#846) 2024-03-15 00:49:40 -07:00
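Background for #846 (not stated in the commit): ALiBi slopes follow a geometric schedule in the number of heads, and the attention kernels expect them in float32, hence the cast. A hedged sketch of building slopes in that dtype, assuming a power-of-two head count:
```
import torch

def alibi_slopes(nheads: int) -> torch.Tensor:
    # Standard ALiBi schedule for a power-of-two number of heads:
    # head i (1-indexed) gets slope 2^(-8 * i / nheads).
    i = torch.arange(1, nheads + 1, dtype=torch.float32)
    return torch.pow(2.0, -8.0 * i / nheads)

# Built directly in fp32 here; the commit adds a cast for callers
# that pass slopes in other dtypes.
slopes = alibi_slopes(8)
print(slopes.dtype, slopes.shape)  # torch.float32 torch.Size([8])
```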
Grigory Sizov
2a15840f09
Enable paged attention in varlen forward (#831)
* Enable paged attention in varlen forward

* Format + fix padding
2024-03-15 00:48:19 -07:00
Tri Dao
6c9e60de56 Bump to v2.5.6 2024-03-01 22:09:56 -08:00
Tri Dao
87a1277653 Bump to v2.5.5 2024-02-21 15:58:23 -08:00
Tri Dao
43950dda45 Bump to v2.5.4 2024-02-20 16:30:16 -08:00
Tri Dao
5cdabc2809 Bump to v2.5.3 2024-02-10 01:06:27 -08:00
Tri Dao
a190df011c Add window_size option to ParallelMHA 2024-02-10 01:02:14 -08:00
Tri Dao
61a7772479 Bump to v2.5.2 2024-01-31 02:44:24 -08:00
Tri Dao
ef0ed10622 Add window_size option to MHA and GPT 2024-01-31 02:42:23 -08:00
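The window_size option surfaced here maps onto the sliding-window (local) attention support in the core kernels, where window_size=(left, right) limits each query to that much left/right context and (-1, -1) means unlimited. A hedged usage sketch against flash_attn_func:
```
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal attention restricted to the previous 256 tokens
# (Mistral-style sliding window).
out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))
```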
Tri Dao
dc72d960a7 [CI] Install torch 2.3 using index 2024-01-30 14:32:29 -08:00
Tri Dao
daf37a9d8a Bump to v2.5.1 2024-01-29 21:03:38 -08:00
Avelina9X
c94cd09744
Added missing docstrings for args and returns in bert_padding.py (#795)
* Updated docstrings of bert_padding.py

Added docstrings for missing arguments in the unpad and pad methods.

* Update bert_padding.py

Fixed spelling mistakes
2024-01-27 09:16:25 -08:00
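For reference, the helpers whose docstrings were filled in: unpad_input packs a padded (batch, seqlen, ...) batch into a (total_tokens, ...) tensor plus bookkeeping, and pad_input inverts it. A hedged usage sketch; the number of return values has varied across releases, four are assumed here:
```
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, hidden = 2, 6, 16
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])  # lengths 3 and 5

# Pack the valid tokens into (total_tokens, hidden) and keep the metadata
# needed to undo the packing (newer releases may return an extra tensor).
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)

# Restore the padded layout.
x_pad = pad_input(x_unpad, indices, batch, seqlen)
assert x_pad.shape == hidden_states.shape
```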
Tao He
204c3c6d1b
Fixes an error in a comment (#785)
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-01-23 12:38:29 -08:00
Tri Dao
197f2083a2 Bump to v2.5.0 2024-01-22 23:40:10 -08:00
Tri Dao
54e80a3829 Implement paged KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
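A hedged sketch of how a paged KV cache is typically consumed: the cache lives in fixed-size blocks shared by all sequences, and a per-sequence block_table maps logical block indices to physical blocks. The shapes below are assumptions consistent with common paged-attention layouts, not taken from the commit:
```
import torch
from flash_attn import flash_attn_with_kvcache

nheads, headdim, block_size, num_blocks, batch = 8, 64, 256, 16, 2

# Physical cache: a pool of fixed-size pages shared across sequences.
k_cache = torch.randn(num_blocks, block_size, nheads, headdim,
                      dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# block_table[b, i] is the page holding tokens
# [i * block_size, (i + 1) * block_size) of sequence b.
block_table = torch.tensor([[0, 1, 2, 3],
                            [4, 5, 6, 7]], dtype=torch.int32, device="cuda")
cache_seqlens = torch.tensor([700, 300], dtype=torch.int32, device="cuda")

q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
out = flash_attn_with_kvcache(q, k_cache, v_cache, cache_seqlens=cache_seqlens,
                              block_table=block_table, causal=True)
```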
Tri Dao
bdcae547c7 [LayerNorm] Don't exit early in the backward pass (fix #781) 2024-01-22 22:40:06 -08:00
Tri Dao
e43a4ceaab [CI] Fix CUDA 12.2.2 compilation 2024-01-21 17:23:39 -08:00
Tri Dao
f9d7376126 Bump to v2.4.3 2024-01-21 17:14:37 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c
return z_loss (#768) 2024-01-21 15:23:41 -08:00
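For context on #768: the z-loss is an auxiliary term that penalizes the squared logsumexp of the logits to keep the softmax normalizer from drifting (as used in PaLM). A generic PyTorch illustration of the quantity being returned; this is not the fused loss's actual signature:
```
import torch
import torch.nn.functional as F

def cross_entropy_with_z_loss(logits, labels, lse_square_scale=1e-4):
    # Standard cross-entropy plus z_loss = scale * logsumexp(logits)^2.
    lse = torch.logsumexp(logits, dim=-1)
    ce = F.cross_entropy(logits, labels)
    z_loss = lse_square_scale * (lse ** 2).mean()
    return ce + z_loss, z_loss  # returning z_loss separately mirrors the PR's intent

logits = torch.randn(4, 32000)
labels = torch.randint(0, 32000, (4,))
loss, z_loss = cross_entropy_with_z_loss(logits, labels)
```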
Tri Dao
a7b66ae25a Simplify writing softmax to gmem 2024-01-13 00:25:04 -08:00
Tri Dao
c9861a032d [LayerNorm] Initialize mean and rstd tensor using x.device 2024-01-09 16:30:31 -08:00
Tri Dao
abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 2024-01-05 00:31:17 -08:00
Tri Dao
f5b308e258 [LayerNorm] Rename layernorm.py -> layer_norm.py 2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2 [LayerNorm] Implement parallel layer norm in Triton 2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5 [LayerNorm] Implement rowscale in Triton layernorm 2024-01-04 01:07:03 -08:00
jiaxingli
386e391117
Fix: implement deterministic backward in mha (#748)
* fix deterministic

* fix deterministic
2024-01-02 18:13:56 -08:00
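The deterministic backward referenced in #748 is exposed as a flag on the attention functions, trading some backward-pass speed for bitwise-reproducible gradients. A hedged usage sketch with flash_attn_func:
```
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 512, 8, 64, dtype=torch.float16, device="cuda", requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# deterministic=True selects a slightly slower backward that avoids
# non-deterministic atomic accumulation of the query gradient.
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
out.sum().backward()
```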
Tri Dao
1a2c3e8c25 Bump to v2.4.2 2023-12-25 16:28:57 -08:00
Tri Dao
73df3be7d5 Add test for BTLM init 2023-12-25 15:16:27 -08:00
Tri Dao
7ffba9a501 Implement BTLM model 2023-12-24 20:35:12 -08:00
Tri Dao
2e29dacf0c Implement muParam 2023-12-24 20:34:48 -08:00
Tri Dao
3f7d5786ba Pass alibi slopes to flash_attn_with_kvcache during generation 2023-12-24 20:31:59 -08:00
Tri Dao
f844852485 Bump to v2.4.1 2023-12-23 21:00:39 -08:00
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
2c7d7b7396 Implement norm head for Baichuan2 2023-12-22 16:55:40 -08:00