Commit Graph

68 Commits

Author SHA1 Message Date
Garrett Byrd
16025d8cc9
Clearer install instructions for CUDA and ROCm backends (#1147)
* Update README.md

* Update README.md

* Update README.md (Added missing line in AMD ROCm Support)
2024-08-13 22:21:56 -07:00
Tri Dao
5ca83a9c71 Clean up softcapping bwd a bit 2024-07-23 00:13:54 -07:00
rocking
d8f104e97a
Support AMD ROCm on FlashAttention 2 (#1010)
* Support CK in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up CK compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update CK to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line; the C++ API will handle this.
Sync with test_flash_attn.py

* Fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* Update CK

* Add mha_varlen_bwd

* Use the same Python as the flash-attn build to generate CK kernels

* Fix bug in group-mode fwd when returning softmax lse

* Increase the test tolerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since mha_bwd is already implemented

* Refine how values are retrieved from the tuple

* Use default parameter for stream_config

* Unblock all platforms

* Add comment

* Refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size to 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for ROCm

* Detect ROCm environment via PyTorch's IS_HIP_EXTENSION

* Update to latest CK

* Add necessary compile flag

* Sync the API with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
2024-07-22 21:34:37 -07:00
Tri Dao
7f67966cc7 FA3 initial code release 2024-07-11 09:53:36 -07:00
Tri Dao
da11d1b853 Bump to v2.6.0 2024-07-10 21:34:58 -07:00
Tri Dao
320fb59487 Update citation 2024-05-26 16:09:03 -07:00
Tri Dao
2406f28805 Enable headdim 256 backward on consumer GPUs (Ampere, Ada) 2024-02-21 15:56:19 -08:00
Tao He
204c3c6d1b
Fix an error in a comment (#785)
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-01-23 12:38:29 -08:00
Tri Dao
54e80a3829 Implement paged KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
Erich Schubert
99ea4baa1d
Typo in README (#760) 2024-01-08 09:59:00 -08:00
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
Tri Dao
50d144c906 Mention Alibi in README 2023-12-21 23:48:16 -08:00
Tri Dao
7f31e7c16a Bump to v2.3.2 2023-10-08 17:21:29 -07:00
Tri Dao
5a83425442 Change constexpr int to constexpr static int 2023-10-08 16:26:33 -07:00
Tri Dao
3a9fe7b0fa Add change log 2023-10-05 14:19:08 -07:00
Tri Dao
aa4fd2d166 Clarify that Windows is not supported right now 2023-10-05 14:00:45 -07:00
Tri Dao
0c04943fa2 Require CUDA 11.6+, clean up setup.py 2023-09-03 21:24:56 -07:00
Jeffrey Quesnelle
1d817a8ffc
fix citation in README (#501) 2023-08-29 11:15:33 -07:00
Tri Dao
45ba93cd96 Add newlines to README 2023-08-24 23:54:13 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
Tri Dao
d30f2e1cd5 Bump to v2.0.4 2023-08-01 09:01:07 -07:00
Ian Timmis
cbf982afa5
README syntax highlighting (#365)
* README syntax highlighting

Adds syntax highlighting to README

* Update README.md
2023-07-23 00:21:30 -07:00
Tri Dao
d1a3b52f17 Add instructions on limiting the number of ninja jobs 2023-07-17 23:17:47 -07:00
Tri Dao
b4cc152e97 Make sure dout is contiguous 2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Tri Dao
ce68305c84 Update installation instruction 2023-05-25 16:52:52 -07:00
Tri Dao
f0c40b7ddb Recommend Nvidia's PyTorch container 2023-05-19 09:41:14 -07:00
Tri Dao
40a25c8ee7 Update roadmap 2023-05-17 08:32:26 -07:00
Anthony Hu
d63cfc3551 Use pyproject.toml to specify build dependencies 2023-04-27 11:51:52 +01:00
Tri Dao
74af023316 Bump version to 1.0.0 2023-04-11 23:32:35 -07:00
Tri Dao
1b18f1b7a1 Support H100 2023-03-15 14:59:02 -07:00
Tri Dao
f28d61cb2a Update README on requirements (nvcc and PyTorch) 2023-03-13 12:48:07 -07:00
Tri Dao
57ee618170
Merge pull request #94 from calebthomas259/main
Add a simple tutorial to README.md
2023-02-14 19:03:08 -08:00
Tri Dao
2dc2a19589 Update roadmap 2023-02-09 12:21:30 -08:00
Caleb Thomas
c9a649805b Add a simple tutorial to README.md 2022-12-27 14:13:59 +08:00
Tri Dao
4a6eaa9f27 Update configs, add results 2022-11-29 04:46:43 -08:00
Tri Dao
45bcf37b97 [Docs] Capitalize the BibTeX citation 2022-11-22 02:12:22 -08:00
Tri Dao
4040256b5e Update pip install instructions, bump to 0.2 2022-11-15 14:10:48 -08:00
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00
Tri Dao
3dda4f76de Update README 2022-11-13 16:52:40 -08:00
Tri Dao
46fd2a20b2 Support all head dims that are multiples of 8, up to 128 2022-10-24 16:04:21 -07:00
Tri Dao
2ed471ecc4 Add tests for numerical error 2022-07-22 17:54:09 -04:00
Tri Dao
42f54d8840 Edit mention of Triton implementation
Phil Tillet suggests calling it "experimental".
2022-07-11 17:02:29 -07:00
Tri Dao
4577151ff8 Link to Triton implementation 2022-07-11 16:01:43 -07:00
Tri Dao
d1fc80a3bb Link to IEEE Spectrum article on MLPerf 2022-07-10 12:11:46 -07:00
Tri Dao
1bbebccc0a Edit README to mention bf16 support 2022-07-09 23:34:29 -07:00
Tri Dao
de19de7ab1 Implement for bf16 2022-07-09 23:31:56 -07:00
Tri Dao
6c3a8c65af Implement cross attention 2022-07-03 17:48:12 -07:00
Tri Dao
450b64fe44 Add README section on issues 2022-06-27 13:50:16 -07:00
Dan Fu
765741c1ee More explanation 2022-06-14 11:55:14 -07:00