Commit Graph

717 Commits

Author SHA1 Message Date
Neil Tenenholtz
7153673c1a Fix swiglu backwards return type (#1337) 2024-11-15 16:23:40 -08:00
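Commit #1337 above fixes the return type of the SwiGLU backward pass. As an illustration only (not the repository's CUDA implementation; the function names here are hypothetical), a scalar SwiGLU and its analytic gradient look like this — a custom backward must return one gradient per forward input, with matching types:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def swiglu(x, gate):
    # SwiGLU gates a linear branch with SiLU: f(x, g) = x * g * sigmoid(g)
    return x * gate * sigmoid(gate)

def swiglu_backward(x, gate, dout):
    # Analytic gradients w.r.t. both forward inputs.
    s = sigmoid(gate)
    dx = dout * gate * s
    dgate = dout * x * (s + gate * s * (1.0 - s))
    return dx, dgate  # one gradient per forward input
```

A quick finite-difference check confirms the two returned gradients.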
Tri Dao
641db759ab [CI] Pytorch 2.5.1 does not support python 3.8 2024-11-12 20:02:13 -08:00
Tri Dao
7435839e3d Update README for FA3 2024-11-12 20:01:07 -08:00
Tri Dao
241c682c9f [CI] Switch back to CUDA 12.4 2024-11-12 14:24:27 -08:00
Tri Dao
c555642172 Bump to v2.7.0 2024-11-12 14:11:44 -08:00
Tri Dao
6ffeb572b1 [CI] Still use CUDA 12.3 but pull the right pytorch version 2024-11-12 14:04:30 -08:00
Ethan Steinberg
42f2b8be34 Use CUDA 12.4 in the build system (#1326)
The current build system uses 12.3, but that causes builds to fail
since there are no official PyTorch releases for 12.3.
2024-11-12 13:40:38 -08:00
Tri Dao
2f6c633179 Drop support for Pytorch 2.0 2024-11-12 11:58:16 -08:00
rocking
88d1657a14 [AMD ROCm] Fix KVcache bug and improve performance (#1328)
* update ck

* update ck

* update ck again

* update ck

* use pointer as seed and offset

* update CK

* Remove useless "else"

* Fix page-attn block table read out-of-bound

---------

Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
2024-11-12 11:32:11 -08:00
Kai Londenberg
284e2c6e5b Make FA3 paged attention ready for upgrade to Cutlass 3.6 (#1331) 2024-11-12 11:31:37 -08:00
Kai Londenberg
b443207c1f Paged Attention support for FA3 (#1268) 2024-11-09 17:05:01 -08:00
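PR #1268 above adds paged attention, where the KV cache lives in fixed-size physical blocks and a per-sequence block table maps logical token positions to physical storage. A minimal sketch of that indirection (hypothetical names, not the FA3 kernel code):

```python
def lookup_kv(block_table, page_size, token_idx):
    """Map a logical token index to (physical_block_id, offset_within_block).

    block_table: ordered list of physical block ids allocated to one sequence.
    """
    logical_block = token_idx // page_size
    offset = token_idx % page_size
    if logical_block >= len(block_table):
        # Guard against the kind of block-table out-of-bound read
        # fixed in later commits.
        raise IndexError("token index beyond allocated blocks")
    return block_table[logical_block], offset
```

With `page_size=16` and blocks `[7, 3, 9]`, token 35 resolves to block 9, offset 3.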
NanoCode012
f0bf3ed9ab Feat: Add support for PyTorch 2.5 in workflows (#1284)
* Feat: Add support for PyTorch 2.5 in workflows

* fix: update to 2.5.1
2024-11-07 00:37:56 -08:00
Son Nguyen
478ee666cc Make namespace comment consistent (#1305)
Co-authored-by: Sony Nguyen <son.nguyen@bytedance.com>
2024-10-30 22:32:49 -07:00
milesvant
c1d146cbd5 Fix copy-paste error in hopper tests (#1279) 2024-10-15 13:54:40 -07:00
jayhshah
a5a75274bc FA3 kvcache + split kv + gqa parallelization (#1236) 2024-10-15 00:21:22 -07:00
Tri Dao
bedf877467 [CrossEntropy] Fix where labels address not aligned to 16 bytes 2024-10-05 02:03:10 -07:00
rocking
53a4f34163 Hotfix due to change of upstream api (#1239) 2024-09-20 12:45:25 -07:00
hlky
8476986721 Fix FAv3 compilation with MSVC (#1240) 2024-09-20 12:44:59 -07:00
Ying Zhang
9cafd4ae14 Merge pull request #1233 from Dao-AILab/ipiszy/local_attn
Add local attention in Hopper FAv3
2024-09-19 23:14:45 -07:00
Ying Zhang
1c9717d699 address comments 2024-09-19 22:50:59 -07:00
Zhihao Shen
30e1ef0f79 minify torch.torch.int32 to torch.int32 (#1237) 2024-09-18 00:32:59 -07:00
Antoni Viros
83e41b3ca4 Add custom ops for compatibility with PT Compile (#1139)
* Add custom ops for compatibility with PT Compile

* Add support for varlen functions too

* Add version checks for pytorch API

* Fix PT compile interfaces so it works e2e

* Make sure PT < 2.4 runs fine

* Fix python mistake

* Fix all the autograd magic issues

* typo on head_dim

* Fix deterministic test failures, remove unneeded detaches()

* remove test requires_grad

* Resolve all the pytorch versioning issues

* C++ and python refactor to improve padding management for torch.compile()

* Add improvements suggested by @anijain2305
2024-09-17 19:49:26 -07:00
Ying Zhang
be6c1b98c4 small fixes 2024-09-16 16:13:00 -07:00
Ying Zhang
dff976a84a fixes 2024-09-16 15:44:44 -07:00
Ying Zhang
7b4e68e04f hopper local attention 2024-09-16 14:59:22 -07:00
Ying Zhang
af314d4006 Merge pull request #1182 from ipiszy/used_q
Add seqused_q in fwd / bwd and seqused_k in bwd in hopper FA.
2024-09-16 14:57:19 -07:00
Ying Zhang
8cbc8a042f small fixes 2024-09-16 14:54:39 -07:00
Ying Zhang
cdbbe844b1 minor changes to unpad_input test util func 2024-09-16 14:24:11 -07:00
Ying Zhang
db80387343 Add seqused_q in fwd / bwd and seqused_k in bwd. 2024-09-16 14:24:11 -07:00
rocking
e2182cc21d Support page kvcache in AMD ROCm (#1198)
* Integrate ck branch of ck_tile/fa_bwd_opt

* Assume dq and q share the same stride

* update ck

* Integrate more stride of dq_acc

* Revert fwd dropout

* Fix parameter order

* Integrate ck with more stride

* update the limit of hdim of bwd

* Check argument

* Add test_flash_attn_causal

* Support unpad lse

* Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic

* Fix stride and Kn0

* Fix CK sync issue

* Fix typo

* Update CK for changing of fmha_fwd_args

* Add kvcache tmp

* Add kvcache

* Fix comment

* Sync behavior with ck

* Update CK to develop

* remove large test case

* Add kvcache test

* Fix page_block_size in arg

* Minor fix

* Fix stride error

* Update seqlen of kvcache before splitkv

* Fix compile error

* Fix bug when hdim is not a multiple of 8

* Fit ck arg

* support adaptive num_splits

* add more tests

* Refine test tolerance

* update CK

* Move override_num_splits_if_necessary into cpp

* update ck

* Update ck

* Support different flag for different version of hip

* remove coerce-illegal, because this is not required in FA

* Update ck to fix scratch memory

* Add coerce-illegal in some version

* Add compile flag for rtn rounding

* remove redundant init

* Using env var to switch rounding mode

* update ck
2024-09-15 23:17:28 -07:00
Tri Dao
cc1690d9d6 [Rotary] Add test for rotary when qkv are packed and there's GQA 2024-09-12 22:35:20 -07:00
Tri Dao
8c20cfef49 [Rotary] Support qkv block layout from GQA 2024-09-11 10:39:58 -07:00
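The two rotary commits above concern applying rotary position embeddings to packed qkv layouts under GQA. For reference, rotary embedding itself rotates consecutive feature pairs by a position-dependent angle; the sketch below is an illustrative pure-Python version (hypothetical helper, not the repository's Triton/CUDA kernel):

```python
import math

def apply_rotary(vec, pos, base=10000.0):
    """Rotate consecutive feature pairs of one head at one position.

    vec: flat list of even length; pair k is rotated by
    theta_k = pos * base**(-2k / d), the standard RoPE frequency schedule.
    """
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + 1]
        out += [x1 * c - x2 * s, x1 * s + x2 * c]
    return out
```

Each pair undergoes a pure rotation, so the embedding preserves vector norms, and position 0 is the identity.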
Charlene Yang
bdf733be55 Add q, k, v descales to FA3 interface (#1210)
* add descale_q/k/v for fp8 fwd

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix .apply args

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
2024-09-09 21:53:52 -07:00
Tri Dao
c7f32a8409 [CrossEntropy] Support precomputed LSE 2024-09-08 09:24:43 -07:00
juejuezi
e371bea04f feat: change minimal supported CUDA version to 11.7 (#1206) 2024-09-05 10:34:35 -07:00
Cameron Shinn
3cea2fb6ee Add ArchTag to pre/postprocess bwd kernels (#1180)
* Add ArchTag to pre/postprocess bwd kernels

* Type-dependent CC check for bwd pre/postprocess

* Fix CC >= 90 for bwd postprocess

---------

Co-authored-by: Cameron Shinn <cshinn@nvidia.com>
2024-08-28 00:20:47 -07:00
jayhshah
c92ca63268 FA3 FP8 qkv descales + restore max offset for h128 causal + added sync for producer WG (#1173) 2024-08-25 12:18:04 -07:00
Tri Dao
d79f9b41a8 [CrossEntropy] Use online softmax to simplify implementation 2024-08-24 17:40:39 -07:00
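The commit above simplifies the cross-entropy kernel with online softmax: the log-sum-exp is computed in one pass by keeping a running maximum and rescaling the running sum whenever the maximum grows. An illustrative pure-Python version (hypothetical names, not the repository's Triton kernel):

```python
import math

def online_logsumexp(xs):
    """One-pass log-sum-exp with a running max (online softmax)."""
    m = float("-inf")  # running maximum
    s = 0.0            # running sum of exp(x - m)
    for x in xs:
        if x > m:
            # New max: rescale the accumulated sum to the new reference.
            s = s * math.exp(m - x) + 1.0
            m = x
        else:
            s += math.exp(x - m)
    return m + math.log(s)

def cross_entropy(logits, label):
    # CE = logsumexp(logits) - logits[label]
    return online_logsumexp(logits) - logits[label]
```

The single pass gives the same result as the naive two-pass (max, then sum) formulation while staying numerically stable.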
Jay Shah
32792d37ec add missing if condition for key_padding_mask in test_util.py 2024-08-19 11:17:17 -07:00
Ying Zhang
28e7f4ddbd Merge pull request #1155 from ipiszy/fix
Fix out-of-bound writes for var-seq-len zero-length KVs
2024-08-17 13:34:06 -07:00
Ying Zhang
53537da422 add a unittest 2024-08-17 13:23:50 -07:00
Ying Zhang
a3a257c71d Fix out-of-bound writes for var-seq-len zero-length KVs 2024-08-16 01:17:40 -07:00
Tri Dao
bcd918f275 [LayerNorm] Add option to write result to out and residual_out 2024-08-15 14:43:47 -07:00
Tri Dao
bd82d6c6eb Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
2024-08-15 12:02:39 -07:00
Tri Dao
800401847e [LayerNorm] Don't store x + residual if we don't need gradients 2024-08-15 11:08:46 -07:00
Garrett Byrd
16025d8cc9 Clearer install instructions for CUDA and ROCm backends (#1147)
* Update README.md

* Update README.md

* Update README.md (Added missing line in AMD ROCm Support)
2024-08-13 22:21:56 -07:00
Ying Zhang
3669b25206 bwd benchmark + small fixes (#1129) 2024-08-05 21:27:52 -07:00
Tri Dao
5d5bfbb619 Remove contiguous checks 2024-08-05 14:47:07 -07:00
SueJane
3f1b4d38e7 Fix: check the type of max_seqlen_k instead of checking max_seqlen twice (#1127) 2024-08-05 08:59:23 -07:00
Tri Dao
3f6ff1c1c5 Remove struct : cute::aligned_struct to avoid error with gcc 12 2024-08-02 00:59:35 -07:00