Commit Graph

68 Commits

rocking
e2182cc21d
Support page kvcache in AMD ROCm (#1198)
* Integrate ck branch of ck_tile/fa_bwd_opt

* Assume dq and q share the same stride

* update ck

* Integrate more stride of dq_acc

* Revert fwd dropout

* Fix parameter order

* Integrate ck with more stride

* update the limit of hdim of bwd

* Check argument

* Add test_flash_attn_causal

* Support unpad lse

* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic

* Fix stride and Kn0

* Fix CK sync issue

* Fix typo

* Update CK for changing of fmha_fwd_args

* Add kvcache tmp

* Add kvcache

* Fix comment

* Sync behavior with ck

* Update CK to develop

* remove large test case

* Add kvcache test

* Fix page_block_size in arg

* Minor fix

* Fix stride error

* Update seqlen of kvcache before splitkv

* Fix compile error

* Fix bug when hdim is not a multiple of 8

* Fit ck arg

* support adaptive num_splits

* add more tests

* Refine test tolerance

* update CK

* Move override_num_splits_if_necessary into cpp

* update ck

* Update ck

* Support different flags for different versions of HIP

* Remove coerce-illegal, because this is not required in FA

* Update ck to fix scratch memory

* Add coerce-illegal in some version

* Add compile flag for rtn rounding

* remove redundant init

* Using env var to switch rounding mode

* update ck
2024-09-15 23:17:28 -07:00
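
The commit above brings paged (block) KV-cache support to the ROCm backend. As a rough illustration only, here is a minimal call to `flash_attn_with_kvcache` using a `block_table`, following the upstream Python API; the shapes, argument names, and `page_block_size` choice are assumptions for the sketch, and the ROCm build may impose different constraints.

```python
# Hedged sketch: paged KV cache via flash_attn_with_kvcache (upstream API;
# shapes and block size here are illustrative assumptions).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, head_dim = 2, 8, 128
page_block_size, num_blocks, blocks_per_seq = 256, 64, 4

# Single-token decode step: q has seqlen 1.
q = torch.randn(batch, 1, nheads, head_dim, device="cuda", dtype=torch.float16)
# Paged cache layout: (num_blocks, page_block_size, nheads, head_dim).
k_cache = torch.randn(num_blocks, page_block_size, nheads, head_dim,
                      device="cuda", dtype=torch.float16)
v_cache = torch.randn_like(k_cache)
# block_table maps each sequence to the cache blocks it owns.
block_table = torch.arange(batch * blocks_per_seq, dtype=torch.int32,
                           device="cuda").reshape(batch, blocks_per_seq)
# Number of tokens already present in the cache for each sequence.
cache_seqlens = torch.full((batch,), 500, dtype=torch.int32, device="cuda")

out = flash_attn_with_kvcache(q, k_cache, v_cache,
                              cache_seqlens=cache_seqlens,
                              block_table=block_table,
                              causal=True)
print(out.shape)  # (batch, 1, nheads, head_dim)
```
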
juejuezi
e371bea04f
feat: change minimal supported CUDA version to 11.7 (#1206) 2024-09-05 10:34:35 -07:00
Tri Dao
65f723bb9a Split bwd into more .cu files to speed up compilation 2024-07-23 01:32:09 -07:00
Tri Dao
751c762c9c Don't specialize for hdim 224 to speed up compilation 2024-07-23 00:13:54 -07:00
rocking
d8f104e97a
Support AMD ROCm on FlashAttention 2 (#1010)
* Support ck in fmha

* Add ck submodule

* Do not return lse if return_softmax == false

* Use receipt to speed up ck compile time

* Integrate new version of ck_tile

* Support dropout for mha_fwd()

* Add dropout to mha_varlen_fwd()

* Update ck to develop

* Extract padding function for dropout randval

* Extract randval transformation function

* Sync the code structure and coding style with FA

* Remove this line; the C++ API will handle it.
Sync with test_flash_attn.py

* fix compile error

* Add mha_bwd

* Generate dropout seed and offset from user generator

* update CK

* Add mha_varlen_bwd

* Use the same Python that builds flash-attn to generate CK kernels

* Fix bug in group-mode fwd when returning softmax lse

* Increase the test tolerance

* Add test_flash_attn_output() and test_flash_attn_varlen_output()

* Always fill softmax_lse

* Remove duplicate benchmark script, since we already implement mha_bwd

* Refine get value from tuple

* Use default parameter for stream_config

* Unblock all platforms

* Add comment

* refine the test code

* Refine naming

* Add unpack to namespace

* Do not hardcode the warp size 64

* Add more targets

* Add README

* Optimize mha_fwd if seqlen_q == 1

* Support get_wheel_url for rocm

* Detect rocm environment by pytorch's IS_HIP_EXTENSION

* Update to latest ck

* Add necessary compile flag

* Sync the api with upstream FA

---------

Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Yichen Yan <wenji.yyc@alibaba-inc.com>
Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
Co-authored-by: Yichen Yan <oraluben@outlook.com>
2024-07-22 21:34:37 -07:00
Tri Dao
844912dca0 [CI] Switch from CUDA 12.2 to 12.3 2024-07-11 08:20:09 -07:00
Tri Dao
908511b2b6 Split into more .cu files to speed up compilation 2024-07-10 00:24:04 -07:00
Tri Dao
beb2bf2a32 Drop support for pytorch 1.12, 1.13, and python 3.7 2024-07-09 22:13:15 -07:00
Nicolas Patry
8f873cc6ac
Implement softcapping. (#1025)
* Softcap v2 (fwd only).

* Some missing interface + remove overrides in tests.
2024-07-08 11:24:48 -07:00
Corey James Levinson
beb8b8ba9f
add exception to Timeout Error (#963)
When the connection times out, you get `URLError: <urlopen error timed out>`. In that case, build it from source.
2024-05-26 12:33:03 -07:00
Wei Ji
9c0e9ee86d
Move packaging and ninja from install_requires to setup_requires (#937)
Set `packaging` and `ninja` as build time dependencies rather than runtime dependencies.
2024-05-06 09:45:54 -07:00
Tri Dao
2aea958f89 [CI] Compile with torch 2.3.0.dev20240207 2024-04-07 20:11:52 -07:00
Arvind Sundararajan
26c9e82743
Support ARM builds (#757) 2024-03-13 21:57:20 -07:00
Chirag Jain
50896ec574
Make nvcc threads configurable via environment variable (#885) 2024-03-13 20:46:57 -07:00
Qubitium
f45bbb4c94
Optimize compilation to (1) avoid OOM, (2) minimize swap usage, and (3) avoid thread starvation, compared with letting ninja decide how many workers to spawn or manually guessing MAX_JOBS. The logic takes the minimum of the MAX_JOBS values auto-calculated from two metrics: (1) CPU cores and (2) free memory. This should allow flash-attn to compile in close to the most efficient manner under any consumer/server environment. (#832) 2024-02-17 18:17:15 -08:00
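
A minimal Python sketch of the heuristic described in that commit, taking the smaller of a CPU-core budget and a free-memory budget as MAX_JOBS; the `psutil` dependency and the per-job memory figure are illustrative assumptions, not the values used in the actual setup.py.

```python
# Hedged sketch of the MAX_JOBS heuristic: cap parallel build jobs by
# whichever budget is smaller, CPU cores or free memory. The 8 GiB-per-job
# figure and the use of psutil are assumptions for illustration.
import os
import psutil

def guess_max_jobs(mem_per_job_gib: float = 8.0) -> int:
    cpu_jobs = os.cpu_count() or 1
    free_gib = psutil.virtual_memory().available / (1024 ** 3)
    mem_jobs = int(free_gib // mem_per_job_gib)
    return max(1, min(cpu_jobs, mem_jobs))

# Respect an explicit MAX_JOBS from the user; otherwise auto-calculate it.
os.environ.setdefault("MAX_JOBS", str(guess_max_jobs()))
print("MAX_JOBS =", os.environ["MAX_JOBS"])
```
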
Tri Dao
d4a7c8ffbb [CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2, add torch-nightly 2023-11-27 16:21:28 -08:00
Tri Dao
5e525a8dc8 [CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1 2023-10-03 22:20:30 -07:00
Tri Dao
1879e089c7 Reduce number of templates for headdim > 128 2023-09-23 22:24:30 -07:00
Tri Dao
bff3147175 Re-enable compilation for Hopper 2023-09-21 23:55:25 -07:00
Tri Dao
dfe29f5e2b [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead 2023-09-18 15:29:06 -07:00
Federico Berto
fa3ddcbaaa
[Minor] add nvcc note on bare_metal_version RuntimeError (#552)
* Add nvcc note on bare_metal_version `RuntimeError`

* Run Black formatting
2023-09-18 11:48:15 -07:00
Tri Dao
799f56fa90 Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults 2023-09-17 22:15:38 -07:00
Tri Dao
bb9beb3645 Remove some unused headers 2023-09-12 12:37:10 -07:00
Tri Dao
0c04943fa2 Require CUDA 11.6+, clean up setup.py 2023-09-03 21:24:56 -07:00
Tri Dao
b1fbbd8337 Implement splitKV attention 2023-08-29 00:58:29 -07:00
Tri Dao
cbb4cf5f46 Don't need to set TORCH_CUDA_ARCH_LIST in setup.py 2023-08-18 14:18:54 -07:00
Aman Gupta Karmani
aab603af4f
fix binary wheel installation when nvcc is not available (#448) 2023-08-14 14:54:26 -07:00
Tri Dao
9c531bdc0a Use single thread compilation for cuda12.1, torch2.1 to avoid OOM CI 2023-08-14 10:03:31 -07:00
Tri Dao
2ddeaa406c Fix wheel building 2023-08-13 16:48:47 -07:00
Tri Dao
3c458cff77 Merge branch 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention into piercefreeman-feature/demo-wheels
* 'feature/demo-wheels' of https://github.com/piercefreeman/flash-attention: (25 commits)
  Install standard non-wheel package
  Remove release creation
  Build wheel on each push
  Isolate 2.0.0 & cuda12
  Clean setup.py imports
  Remove builder project
  Bump version
  Add notes to github action workflow
  Add torch dependency to final build
  Exclude cuda erroring builds
  Exclude additional disallowed matrix params
  Full version matrix
  Add CUDA 11.7
  Release is actually unsupported
  echo OS version
  Temp disable deploy
  OS version build numbers
  Restore full build matrix
  Refactor and clean of setup.py
  Strip cuda name from torch version
  ...
2023-08-13 16:03:51 -07:00
Tri Dao
1c41d2b0e5 Fix race condition in bwd (overwriting sK) 2023-08-01 09:00:10 -07:00
Tri Dao
4f285b3547 FlashAttention-2 release 2023-07-17 06:21:34 -07:00
Pierce Freeman
9af165c389 Clean setup.py imports 2023-06-07 17:27:36 -07:00
Pierce Freeman
494b2aa486 Add notes to github action workflow 2023-06-07 17:06:12 -07:00
Pierce Freeman
ea2ed88623 Refactor and clean of setup.py 2023-06-02 18:25:07 -07:00
Pierce Freeman
9fc9820a5b Strip cuda name from torch version 2023-06-02 18:25:07 -07:00
Pierce Freeman
5e4699782a Allow fallback install 2023-06-02 18:25:07 -07:00
Pierce Freeman
0e7769c813 Guessing wheel URL 2023-06-02 18:25:07 -07:00
Pierce Freeman
e1faefce9d Raise cuda error on build 2023-06-02 18:25:07 -07:00
Pierce Freeman
add4f0bc42 Scaffolding for wheel prototype 2023-06-02 18:25:07 -07:00
Max H. Gerlach
31f78a9814 Allow adding an optional local version to the package version 2023-05-19 17:27:41 +02:00
Tri Dao
eff9fe6b80 Add ninja to pyproject.toml build-system, bump to v1.0.5 2023-05-12 14:20:31 -07:00
Tri Dao
ad113948a6 [Docs] Clearer error message for bwd d > 64, bump to v1.0.4 2023-04-26 09:19:48 -07:00
Tri Dao
fbbb107848 Bump version to v1.0.3.post0 2023-04-21 13:37:23 -07:00
Tri Dao
67ef5d28df Bump version to 1.0.3 2023-04-21 12:04:53 -07:00
Tri Dao
df1344f866 Bump to v1.0.2 2023-04-15 22:19:31 -07:00
Pavel Shvets
72629ac9ba Add missing module 2023-04-14 20:08:24 +03:00
Tri Dao
853ff72963 Bump version to v1.0.1, fix Cutlass version 2023-04-12 10:05:01 -07:00
Tri Dao
74af023316 Bump version to 1.0.0 2023-04-11 23:32:35 -07:00
Tri Dao
dc08ea1c33 Support H100 for other CUDA extensions 2023-03-15 16:59:27 -07:00