Commit Graph

717 Commits

Author SHA1 Message Date
Tri Dao
c33de664a1 Fix import in test 2024-08-01 02:14:25 -07:00
Tri Dao
bafe253042 [FA3] Bwd 2024-08-01 01:57:06 -07:00
Ying Zhang
abffb0f98c
Merge pull request #1115 from ipiszy/bench
Add cudnn benchmark for var-len
2024-07-31 22:42:06 -07:00
Ying Zhang
c7f20a2d31 add cudnn benchmark for var-len 2024-07-31 22:33:29 -07:00
jayhshah
5018ac6ac5
Fp8 kernel with "in-kernel" transpose of V in producer (#1100)
* base version

* restructure pipelines, add special fp8 epilogue

* add variants

* add fp8 causal and modify dynamic tile scheduler

* better causal schedule

* maintain two schedules for non causal and causal

* removing macros

* fix regression

* clean up unneeded methods and variants

* fix mistake with NumProducerThreads

* use seqlen traits

* add fp8 .cu files and benchmark script

* fix merge issue

* fix merge issue

* fix merge issue

* remove duplicate code

* fix regression with varseqlen

* move varseqlen init in constexpr

* fix test script

* more constexpr on varseqlen and add max offset

* add back test cases
2024-07-30 14:14:14 -07:00
Tri Dao
c4b9015d74 Add benchmark_gemm.py 2024-07-27 11:13:18 -07:00
Tri Dao
418d677192 Bump to v2.6.3 2024-07-25 01:31:28 -07:00
Tri Dao
65205d350e [CI] Compile for pytorch 2.4.0 2024-07-25 01:30:34 -07:00
Tri Dao
3aae9c18c1 Revert "Changes For FP8 (#1075)"
This reverts commit 1899c970c8.
2024-07-25 01:28:44 -07:00
ganeshcolfax
1899c970c8
Changes For FP8 (#1075)
* adding files for fp8 changes.

* removed contiguous check.

* enable all tests except odd-seq-lengths, where it crashes now.

* undid clang formatting.

* change to correct tile size for headdim=128.

* fixed odd-seq-len-k.

* minor formatting.

* minor reformatting.

---------

Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
2024-07-23 13:51:14 -07:00
Tri Dao
59594f2a67 Bump to v2.6.2 2024-07-23 02:30:05 -07:00
Tri Dao
299563626f Fix test with alibi and cache_leftpad 2024-07-23 02:04:15 -07:00
Tri Dao
4488acee8d [CI] Compile with torch 2.4.0.dev20240527 2024-07-23 01:33:32 -07:00
Tri Dao
65f723bb9a Split bwd into more .cu files to speed up compilation 2024-07-23 01:32:09 -07:00
Tri Dao
5ca83a9c71 Clean up softcapping bwd a bit 2024-07-23 00:13:54 -07:00
Tri Dao
751c762c9c Don't specialize for hdim 224 to speed up compilation 2024-07-23 00:13:54 -07:00
Driss Guessous
1c275eb070
Fix ima for split-kv kernel (#1085) 2024-07-22 22:19:46 -07:00
janEbert
3c4053b75c
Make FA3 externally importable (#1053)
The library name to import is `flash_attn_interface`, which matches the test.
2024-07-22 21:34:56 -07:00
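A minimal sketch of what this enables, assuming the FA3 extension under `hopper/` has been built and installed (the `dir()` probe is only illustrative, not part of the PR):

```python
# Sketch only: requires the FA3 (Hopper) extension to be built and installed first.
import flash_attn_interface  # module name taken from this commit message

# List the attention entry points the module exposes.
print([name for name in dir(flash_attn_interface) if name.startswith("flash_attn")])
```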
rocking
d8f104e97a
Support AMD ROCm on FlashAttention 2 (#1010)
* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line, c++ api will handle this.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use same python as build flash-attn to generate ck kernel

* Fix bug in group mode fwd when returning softmax lse

* increase the test tolerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine getting values from tuples

* Use default parameter for stream_config

* unblock all platforms

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* update to latest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
2024-07-22 21:34:37 -07:00
Ying Zhang
dfe1a59e4b
Add var-seq-len to FA3 fp16 / bf16 fwd (#1072)
* fwd var-seq-len

* fixes

* benchmark

* fixes

---------

Co-authored-by: Tri Dao <tridao@users.noreply.github.com>
2024-07-22 21:32:41 -07:00
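For context, the var-seq-len ("varlen") path consumes a packed layout: the sequences in a batch are concatenated along the token dimension and delimited by a cumulative-length tensor. A minimal packing sketch (tensor names and shapes are illustrative, not taken from this PR):

```python
import torch

# Three sequences of different lengths packed into one (total_tokens, nheads, headdim) tensor.
seqlens = torch.tensor([5, 17, 9], dtype=torch.int32)
nheads, headdim = 8, 128
q_packed = torch.randn(int(seqlens.sum()), nheads, headdim, dtype=torch.float16)

# cu_seqlens holds each sequence's start offset plus the total length: [0, 5, 22, 31].
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
max_seqlen = int(seqlens.max())
# q_packed, cu_seqlens, and max_seqlen are the inputs a varlen forward consumes.
```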
Cameron Shinn
cb516f855b
Remove torchlib dependency from cpp files (#1083) 2024-07-22 16:47:09 -07:00
Phil Wang
5f1ae4a34b
backwards for softcapping (#1033)
* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
2024-07-21 23:25:46 -07:00
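For context on the `dtanh` bullets above: softcapping squashes the raw scores s to cap * tanh(s / cap) before the softmax, so the backward pass has to scale incoming gradients by the derivative of that map. A minimal sketch of that factor (my reading of the commit, not code from the PR):

```python
import torch

def softcap_dtanh(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # d/ds [cap * tanh(s / cap)] = 1 - tanh(s / cap) ** 2,
    # the per-element factor multiplied into the score gradients.
    t = torch.tanh(scores / cap)
    return 1.0 - t * t
```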
youkaichao
ef3e358a25
remove lambda (#1056) 2024-07-21 23:24:38 -07:00
Jorge António
4df62e1440
catch typo (#1058) 2024-07-21 23:24:15 -07:00
Tri Dao
74b0761ff7 [FA3] BF16 forward 2024-07-14 23:39:46 -07:00
Tri Dao
898dd4bbf2 Pass seqused_k to _flash_attn_varlen_forward 2024-07-13 00:08:27 -07:00
Tri Dao
7ef24848cf Add FA3 image 2024-07-11 09:54:05 -07:00
Tri Dao
7f67966cc7 FA3 initial code release 2024-07-11 09:53:36 -07:00
Tri Dao
b4a9dd6c9c Temporarily switch to cutlass fork for more shapes 2024-07-11 09:29:21 -07:00
Tri Dao
7551202cb2 Bump to v2.6.1 2024-07-11 08:28:32 -07:00
Tri Dao
844912dca0 [CI] Switch from CUDA 12.2 to 12.3 2024-07-11 08:20:09 -07:00
Tri Dao
40e534a7f6 Implement cache_leftpad 2024-07-11 08:17:15 -07:00
Tri Dao
116b05f9b0 [CI] Compile with pytorch 2.4.0.dev20240514 2024-07-11 02:53:30 -07:00
Tri Dao
da11d1b853 Bump v2.6.0 2024-07-10 21:34:58 -07:00
Tri Dao
d0787acc16 Relax dropout_fraction test 2024-07-10 11:49:40 -07:00
Tri Dao
dca6d89da4 Don't support softcap and dropout at the same time
These tests are failing so I'm just disabling this case for now
2024-07-10 11:23:12 -07:00
Tri Dao
81e01efd4b More typo fixes 2024-07-10 10:19:17 -07:00
Tri Dao
72e27c6320 Fix typo with softcapping 2024-07-10 00:33:52 -07:00
Tri Dao
3d41db3e2c Only test backward if there's no softcapping 2024-07-10 00:27:45 -07:00
Tri Dao
908511b2b6 Split into more .cu files to speed up compilation 2024-07-10 00:24:04 -07:00
Tri Dao
1d536d7de5 Minor cleanup of softcapping 2024-07-09 22:57:03 -07:00
Tri Dao
beb2bf2a32 Drop support for pytorch 1.12, 1.13, and python 3.7 2024-07-09 22:13:15 -07:00
Phil Wang
f4628b43ec
missing commas and backwards return arguments (#1032)
* missing commas

* another fix
2024-07-09 10:56:29 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
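A short sketch of the forward-only softcapping introduced here, assuming the usual tanh-based formulation applied to the raw QK^T scores before the softmax:

```python
import torch

def softcap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Bound raw attention scores to (-cap, cap) while staying roughly linear near zero.
    return cap * torch.tanh(scores / cap)
```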
Jianwei Dong
4e8d60069f
Add the return_softmax_lse parameter to the flash_attn_with_kvcache function to allow returning the logsumexp of the attention scores. (#989) 2024-07-08 08:29:40 -07:00
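A hedged usage sketch of the new flag (the shapes, `cache_seqlens` values, and the `(out, lse)` return layout are assumptions here, not taken from the commit):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 128
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(batch, 1024, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros(batch, 1024, nheads, headdim, dtype=torch.float16, device="cuda")
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device="cuda")

# With the flag set, the logsumexp of the attention scores comes back alongside the output.
out, lse = flash_attn_with_kvcache(
    q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True,
    return_softmax_lse=True,  # parameter added by this commit
)
```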
muoshuosha
6df7e0a02e
Fix the varlen deterministic test (#1023)
Co-authored-by: moshuosha <moshuosha@qq.com>
2024-07-03 11:07:57 -07:00
66RING
9486635c92
Fix typos of comments about shape. (#837) 2024-06-30 22:40:59 -07:00
JDKWangGuan
0d810cfb73
Fix KeyError handling for non-existing key in state_dict.pop() (#898)
Update handling for KeyError in state_dict.pop() for non-existing keys.
Changed state_dict.pop(f"h.{d}.attn.bias") to state_dict.pop(f"h.{d}.attn.bias", None) to prevent KeyError exceptions.


The following code reproduces the issue:
```python
from transformers import AutoTokenizer, GPT2Model, GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel, GPTModel

# >>> transformers.__version__
# '4.38.2'

model_path = 'gpt2'
output_model_path = 'gpt2_model'
config = GPT2Config.from_pretrained(model_path, output_hidden_states=True)
model = GPT2Model.from_pretrained(model_path, from_tf=False, config=config)
'''
model fine-tuning here
'''
# dump the fine-tuned model
model.save_pretrained(output_model_path)

# load the fine-tuned model
config = GPT2Config.from_pretrained(output_model_path, output_hidden_states=True)
model = GPTModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'
model = GPTLMHeadModel.from_pretrained(output_model_path, config=config, strict=True)  # failed due to KeyError: 'h.0.attn.bias'

```
2024-06-30 22:40:03 -07:00
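The fix boils down to `dict.pop`'s optional default; a minimal illustration (plain Python, not the repo's code):

```python
state_dict = {"h.0.attn.masked_bias": 1.0}  # checkpoint without an "h.0.attn.bias" entry

# state_dict.pop("h.0.attn.bias")       # would raise KeyError
state_dict.pop("h.0.attn.bias", None)   # returns None instead of raising, which is the fix
```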
cao lei
6a2a16e994
fix typo (#974) 2024-06-30 22:39:39 -07:00
Nicolas Patry
5bf201966a
Fixing argument checking when using seqlenq_ngroups_swapped. (#976)
When the user passes `out` as a parameter and the arguments trigger
`seqlenq_ngroups_swapped`, the CHECK_SHAPE is incorrect (since the q shape is modified).
2024-06-30 22:39:22 -07:00