Commit Graph

524 Commits

Author SHA1 Message Date
Tri Dao
732654583c Implement deterministic backward (thanks to Meituan) 2023-12-23 17:57:36 -08:00
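
The deterministic backward added in this commit surfaces in the Python interface as a `deterministic` flag. A minimal sketch against the v2.4-era `flash_attn` API (shapes are illustrative):

```python
import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim) fp16/bf16 tensors on CUDA
q = torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# deterministic=True makes the backward pass bitwise-reproducible run to run
# (it avoids nondeterministic accumulation of dq), at some cost in speed/memory.
out = flash_attn_func(q, k, v, causal=True, deterministic=True)
out.sum().backward()
```
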
Tri Dao
2c7d7b7396 Implement norm head for Baichuan2 2023-12-22 16:55:40 -08:00
Tri Dao
68f178aa4b [CI] Don't compile for python 3.7 pytorch 2.2 2023-12-22 10:10:02 -08:00
Tri Dao
7316277303 Bump to v2.4.0 2023-12-22 00:09:53 -08:00
Tri Dao
50d144c906 Mention Alibi in README 2023-12-21 23:48:16 -08:00
Tri Dao
8448c02889 Update cutlass to v3.3.0 2023-12-21 23:25:50 -08:00
Tri Dao
c3b2196652 Add Alibi to MHA, test with Baichuan-13B 2023-12-21 22:49:55 -08:00
Tri Dao
701b51bfc3 [CI] Use torch-nightly 20231106 instead of 20231127 2023-12-21 22:28:09 -08:00
Tri Dao
5ab9b3667b Clean up alibi, implement non-causal alibi 2023-12-21 22:27:40 -08:00
Tri Dao
bc28eacc60 Format flash_attn_interface.py 2023-12-19 23:13:53 -08:00
Tri Dao
0a146185d6 [Gen] Remove minor dead code 2023-12-19 22:57:39 -08:00
Sanghun Cho
e4f726fc44
Support alibi, by Sanghun Cho from Kakao Brain
* hard-code alibi in fwd
* use params.h as num_heads
* hard-code alibi in bwd
* add alibi on/off option
* compute alibi_start, ratio outside of kernels
* fix minor merge conflict
* add test_alibi.py
* change apply_alibi() location before masking
* add alibi in splitkv kernel
* fix backward func # of returns
* add out-of-bound check in apply_alibi()
* update test_alibi.py
* update test_alibi.py for kvcache
* simplify alibi parameter interface
* fix performance issue by computing alibi outside of branch
* update test_flash_attn_varlen_func() for left padding
* implement alibi_slopes (b, nh) loading
* optimize apply_alibi() a bit
* update test cases for alibi_slopes loading
* reflect stylistic comments
* disable "seqlenq_ngroups_swapped" when using alibi

---------

Co-authored-by: monk.detective <monk.detective@kakaobrain.com>
2023-12-19 22:56:06 -08:00
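
The ALiBi support merged here lands in the Python API as an `alibi_slopes` argument: an fp32 tensor of shape `(nheads,)` or `(batch, nheads)` whose per-head slope is applied as a linear bias proportional to the query-key distance. A minimal sketch, assuming the `flash_attn_func` entry point:

```python
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Standard geometric ALiBi slopes, one per head; must be fp32.
alibi_slopes = torch.tensor(
    [2.0 ** (-8.0 * (i + 1) / nheads) for i in range(nheads)],
    device="cuda", dtype=torch.float32,
)

out = flash_attn_func(q, k, v, causal=True, alibi_slopes=alibi_slopes)
```
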
Tri Dao
cd089597fd [LayerNorm] Implement dropout in fused residual + LN/RMSNorm 2023-12-19 16:26:07 -08:00
Tri Dao
713bd3aa9a [CrossEntropy] Test longer sequences 2023-12-16 19:11:23 -08:00
Tri Dao
08124c8f9c [CrossEntropy] Implement logit_scale option 2023-12-16 18:39:37 -08:00
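
The `logit_scale` option multiplies the logits inside the fused kernel before the softmax, so the scaled logits are never materialized. A sketch, assuming the Triton loss at `flash_attn.losses.cross_entropy` takes the scale as a constructor argument:

```python
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

loss_fn = CrossEntropyLoss(logit_scale=2.0)  # illustrative scale value
logits = torch.randn(8, 32000, device="cuda", dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, 32000, (8,), device="cuda")
loss = loss_fn(logits, labels)  # equivalent to CE on logit_scale * logits
loss.backward()
```
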
Tri Dao
9356a1c038 [LayerNorm] Implement layer_norm_linear 2023-11-30 21:46:07 -08:00
Tri Dao
92dd5703ec Bump to v2.3.6 2023-11-27 16:23:39 -08:00
Tri Dao
d4a7c8ffbb [CI] Only compile for CUDA 11.8 & 12.2, MAX_JOBS=2, add torch-nightly 2023-11-27 16:21:28 -08:00
Jeremy Reizenstein
ce3e7280f8
Allow varlen_fwd to take optional seqused_k (#647)
Co-authored-by: bottler <bottler@users.noreply.github.com>
2023-11-27 00:41:23 -08:00
Tri Dao
23b77c8148 Bump to v2.3.5 2023-11-26 19:08:28 -08:00
Tri Dao
b4bf9cc1f3 Fix performance regression with causal 2023-11-26 19:07:25 -08:00
Tri Dao
2c3baba4a6 Bump to v2.3.4 2023-11-19 23:21:31 -08:00
Tri Dao
aaa1474129 [CrossEntropy] Simplify the case of large vocab with Tensor Parallel 2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash ce mp large vocab (#673) 2023-11-19 23:01:07 -08:00
Tri Dao
db2f80692c Write zero to out / grad if seqlen_q or seqlen_k is zero 2023-11-19 22:20:01 -08:00
Tri Dao
43bb6d8aaa Update cutlass to 3.2.2 2023-11-19 21:43:48 -08:00
Driss Guessous
dc4b9ad6c4
add checks (#640) 2023-11-19 20:43:27 -08:00
Tri Dao
017716451d [LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton 2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton 2023-11-13 02:04:49 -08:00
Antony Frolov
3566596ad8
Fix typo in RotaryEmbedding forward output type (#666) 2023-11-09 11:43:02 -08:00
Tri Dao
83aef842be Bump to v2.3.3 2023-10-24 00:24:07 -07:00
Tri Dao
c79de85ffa [CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements 2023-10-24 00:17:34 -07:00
Tri Dao
02ac572f3f Clarify inference README is a placeholder 2023-10-12 10:14:58 -07:00
Tri Dao
7f31e7c16a Bump to v2.3.2 2023-10-08 17:21:29 -07:00
Tri Dao
5a83425442 Change constexpr int to constexpr static int 2023-10-08 16:26:33 -07:00
Tri Dao
3a9fe7b0fa Add change log 2023-10-05 14:19:08 -07:00
Tri Dao
aa4fd2d166 Clarify that Windows is not supported right now 2023-10-05 14:00:45 -07:00
Tri Dao
5e525a8dc8 [CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1 2023-10-03 22:20:30 -07:00
Tri Dao
21c3b0d8f6 Bump to v2.3.1 2023-10-03 19:56:45 -07:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
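
`cache_batch_idx` decouples the batch order of a decode step from the rows of a pre-allocated KV cache. A rough sketch, assuming the `flash_attn_with_kvcache` interface (shapes illustrative):

```python
import torch
from flash_attn import flash_attn_with_kvcache

cache_size, max_seqlen, nheads, headdim = 8, 4096, 16, 64
k_cache = torch.zeros(cache_size, max_seqlen, nheads, headdim,
                      device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)

# One decode token for each of 3 sequences stored in cache rows 5, 0, 2.
q = torch.randn(3, 1, nheads, headdim, device="cuda", dtype=torch.float16)
k_new, v_new = torch.randn_like(q), torch.randn_like(q)
cache_seqlens = torch.tensor([17, 3, 250], dtype=torch.int32, device="cuda")
cache_batch_idx = torch.tensor([5, 0, 2], dtype=torch.int32, device="cuda")

# k_new/v_new are appended into the selected cache rows at cache_seqlens,
# then attention runs against those rows.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens, cache_batch_idx=cache_batch_idx, causal=True,
)
```
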
Tri Dao
601b4dc48d Bump to v2.3.0 2023-09-26 22:08:29 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
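
Local attention here means sliding-window attention, exposed as a `window_size=(left, right)` pair where -1 leaves that side unbounded. A minimal sketch, assuming `flash_attn_func`:

```python
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 4096, 16, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Causal sliding window: each token attends to itself and at most the
# 256 tokens before it.
out = flash_attn_func(q, k, v, causal=True, window_size=(256, 0))
```
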
Katherine Crowson
4c8ff9154e
Fix NameError and typo in ApplyRotaryEmbQKV_ (#569) 2023-09-25 10:47:34 -07:00
Tri Dao
0a1d03c7ea Bump to v2.2.5 2023-09-24 00:54:03 -07:00
Tri Dao
812cb1c990 Switch cutlass to newer commit to avoid compilation warning 2023-09-24 00:42:50 -07:00
Tri Dao
65c234ed90 Don't over-allocate dq_accum in case of varlen 2023-09-24 00:36:07 -07:00
Tri Dao
1879e089c7 Reduce number of templates for headdim > 128 2023-09-23 22:24:30 -07:00
Tri Dao
dd9a6fa45a Add placeholder for inference example 2023-09-22 02:31:00 -07:00
Tri Dao
bff3147175 Re-enable compilation for Hopper 2023-09-21 23:55:25 -07:00
Yuchao Dai
187c2a0635
Fix E1136 (#563) 2023-09-21 11:48:23 -07:00