Commit Graph

631 Commits

Author SHA1 Message Date
Tri Dao
d9a5cb291c Fix dv = torch::empty_like(k) for mha_bwd_varlen as well 2024-02-10 01:03:00 -08:00
Tri Dao
a190df011c Add window_size option to ParallelMHA 2024-02-10 01:02:14 -08:00
Brian Hirsh
2423cca3ad fix backward for when query and key have different contiguity (#818) 2024-02-10 01:01:27 -08:00
Grigory Sizov
4687936413 Fix Windows build (#816) 2024-02-07 17:41:53 -08:00
Tri Dao
61a7772479 Bump to v2.5.2 2024-01-31 02:44:24 -08:00
Tri Dao
6a5c053c3e [CI] Compile with torch 2.2.0 instead of 2.2.0.dev20231106 2024-01-31 02:43:12 -08:00
Tri Dao
ef0ed10622 Add window_size option to MHA and GPT 2024-01-31 02:42:23 -08:00
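The window_size option added in this commit corresponds to sliding-window ("local") attention. Below is a minimal, hedged sketch of the equivalent argument on the core flash_attn_func API; the shapes and values are illustrative, and the assumption that the MHA/GPT modules forward the same (left, right) tuple is not taken from this commit. Assumes flash_attn >= 2.3 on a CUDA GPU.

```python
# Hedged sketch: sliding-window (local) attention via window_size.
# Assumes flash_attn >= 2.3 and an fp16/bf16-capable CUDA GPU.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# window_size=(left, right): each query attends to at most `left` positions before it
# and `right` positions after it; (-1, -1) disables the window (full attention).
out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))
print(out.shape)  # (batch, seqlen, nheads, headdim)
```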
Tri Dao
dc72d960a7 [CI] Install torch 2.3 using index 2024-01-30 14:32:29 -08:00
Tri Dao
daf37a9d8a Bump to v2.5.1 2024-01-29 21:03:38 -08:00
Tri Dao
aa2eb8ddf2 [CI] Compile with pytorch 2.2.0.dev20231106 2024-01-29 20:49:18 -08:00
Jeremy Reizenstein
0658e320f6 Preprocessor switches to control functionality (#788)
For faster and smaller builds in some simple cases, provide switches to allow disabling:
- backward
- alibi
- uneven k
- dropout
- local attention

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
2024-01-29 20:44:23 -08:00
Christian Kadner
290596c544 [CI] Build wheels for Pytorch 2.3 (dev/nightly) (#793)
* [CI] Build wheels for Pytorch 2.3 (dev/nightly) (Resolves #790)
* update TORCH_CUDA_VERSION
* revert torch 2.2 back to dev20231130
* add link to PyTorch compatibility matrix

Signed-off-by: Christian Kadner <ckadner@us.ibm.com>
2024-01-29 17:53:38 -08:00
Avelina9X
c94cd09744 Updated missing docstrings for args and returns in bert_padding.py (#795)
* Updated docstrings of bert_padding.py: added docstrings for missing arguments in the unpad and pad methods.
* Update bert_padding.py: fixed spelling mistakes.
2024-01-27 09:16:25 -08:00
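For the bert_padding entry above, a hedged usage sketch of the unpad/pad pair whose docstrings were updated, assuming flash_attn 2.5.x (the return arity of unpad_input has changed in later releases, so treat the unpacking below as version-specific):

```python
# Hedged sketch: removing and restoring padding around a varlen attention call.
# Assumes flash_attn 2.5.x; unpad_input's return arity differs in later releases.
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 8, 16
hidden_states = torch.randn(batch, seqlen, dim, dtype=torch.float16, device="cuda")
attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3], dtype=torch.bool, device="cuda")

# Pack only the non-padded tokens into one (total_tokens, dim) tensor.
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
# ... run a varlen kernel on x_unpad using cu_seqlens / max_seqlen ...

# Scatter the tokens back to (batch, seqlen, dim), zero-filling the padded slots.
restored = pad_input(x_unpad, indices, batch, seqlen)
```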
Tri Dao
ffc8682dd5 Add benchmarking code for Alibi (from Sanghun Cho) 2024-01-23 19:00:49 -08:00
Tao He
204c3c6d1b Fixes an error in comment (#785)
Signed-off-by: Tao He <sighingnow@gmail.com>
2024-01-23 12:38:29 -08:00
Tri Dao
197f2083a2 Bump to v2.5.0 2024-01-22 23:40:10 -08:00
Tri Dao
54e80a3829 Implement page KV cache
Co-authored-by: ljss <450993438@qq.com>
2024-01-22 22:47:30 -08:00
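For the paged KV cache entry above, a hedged sketch of single-token decoding against a paged cache via flash_attn_with_kvcache and its block_table argument, assuming flash_attn >= 2.5.0. The page size of 256 and the page-table layout below are illustrative assumptions; the block-size constraint varies across releases.

```python
# Hedged sketch: one decode step against a paged KV cache (assumes flash_attn >= 2.5.0).
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim = 2, 8, 64
page_size, num_pages = 256, 16  # 8 pages per sequence here -> 2048 tokens of capacity each

# Paged layout: (num_pages, page_size, nheads_k, headdim) instead of one contiguous cache per sequence.
k_cache = torch.zeros(num_pages, page_size, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# block_table[b, i] = physical page holding the i-th logical page of sequence b.
block_table = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(batch, -1)
cache_seqlens = torch.tensor([100, 37], dtype=torch.int32, device="cuda")  # tokens already cached

q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")  # one new query token
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    cache_seqlens=cache_seqlens,
    block_table=block_table,
    causal=True,
)
print(out.shape)  # (batch, 1, nheads, headdim)
```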
Tri Dao
bdcae547c7 [LayerNorm] Don't exit early in the backward pass (fix #781) 2024-01-22 22:40:06 -08:00
Tri Dao
36bc29edf7 Use int64_t instead of uint32_t in kernel_traits.h 2024-01-22 22:39:29 -08:00
Tri Dao
000b67f5d8 Use int64_t instead of uint32_t for index_t 2024-01-22 11:25:50 -08:00
Tri Dao
e43a4ceaab [CI] Fix CUDA 12.2.2 compilation 2024-01-21 17:23:39 -08:00
Tri Dao
f9d7376126 Bump to v2.4.3 2024-01-21 17:14:37 -08:00
Tri Dao
0399432d68 [CI] Use CUDA 12.2.2 instead of 12.2.0 2024-01-21 15:35:57 -08:00
Tri Dao
ea8a25ca38 Remove configure in bwd kernel launch 2024-01-21 15:28:33 -08:00
Grigory Sizov
af01244ddd Add split-kv and M<->H swap to varlen forward decoding attention (#754)
* Add split-k, M<->H to varseq path
* skip M<->H when dropout>0, fix LSE
2024-01-21 15:28:36 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c return z_loss (#768) 2024-01-21 15:23:41 -08:00
Tri Dao
43ceab630b [CI] Use torch nightly 20231130 instead of 20231106 2024-01-20 22:31:04 -08:00
Tri Dao
8f4d82cf5e Update cutlass to v3.4.0 2024-01-20 22:30:06 -08:00
Tri Dao
395e5a0dba Move rotary device functions to a separate file 2024-01-20 18:01:18 -08:00
Tri Dao
3e2c827d9a Remove unused kernel_traits file 2024-01-20 17:41:44 -08:00
Tri Dao
66a127aef8 Refactor masking in fwd pass into 1 object 2024-01-20 17:39:53 -08:00
Tri Dao
ed4959b2eb Change inline to __forceinline__, use __grid_constant__ param 2024-01-20 17:38:47 -08:00
Tri Dao
6f706eff96 Make Softmax an object 2024-01-19 16:09:31 -08:00
Tri Dao
4ea866ca19 Make Alibi an object 2024-01-15 00:07:11 -08:00
Tri Dao
5aca153d6d Move bwd preprocess kernels to a separate file 2024-01-14 16:57:03 -08:00
Tri Dao
df1418f9db Move softmax_rescale_o to softmax.h 2024-01-14 15:06:06 -08:00
Tri Dao
6777336a1c Move masking to a separate file (mask.h) 2024-01-14 12:43:47 -08:00
Tri Dao
9448264ddd Remove seqq_parallel backward kernel that's not used 2024-01-14 12:25:49 -08:00
Tri Dao
1274ec3e7e Move dropout to a separate file (dropout.h) 2024-01-14 12:19:17 -08:00
Tri Dao
10dad61277 apply_dropout now takes tensor of rowcol layout 2024-01-14 01:03:23 -08:00
Tri Dao
d9cbcfb41c Remove dead code in philox.cuh 2024-01-13 02:02:03 -08:00
Tri Dao
a7b66ae25a Simplify writing softmax to gmem 2024-01-13 00:25:04 -08:00
Tri Dao
8d1b169ed1 Simplify SmemLayoutVtransposed in kernel_traits.h 2024-01-12 11:53:29 -08:00
Tri Dao
c9861a032d [LayerNorm] Initialize mean and rstd tensor using x.device 2024-01-09 16:30:31 -08:00
Erich Schubert
99ea4baa1d Typo in README (#760) 2024-01-08 09:59:00 -08:00
Tri Dao
abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 2024-01-05 00:31:17 -08:00
Tri Dao
f5b308e258 [LayerNorm] Rename layernorm.py -> layer_norm.py 2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2 [LayerNorm] Implement parallel layer norm in Triton 2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5 [LayerNorm] Implement rowscale in Triton layernorm 2024-01-04 01:07:03 -08:00
jiaxingli
386e391117 Fix: implement deterministic backward in mha (#748)
* fix deterministic
2024-01-02 18:13:56 -08:00