Commit Graph

553 Commits

Author SHA1 Message Date
Tri Dao
2c3baba4a6 Bump to v2.3.4 2023-11-19 23:21:31 -08:00
Tri Dao
aaa1474129 [CrossEntropy] Simplify the case of large vocab with Tensor Parallel 2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash ce mp large vocab (#673) 2023-11-19 23:01:07 -08:00
Tri Dao
db2f80692c Write zero to out / grad if seqlen_q or seqlen_k is zero 2023-11-19 22:20:01 -08:00
Tri Dao
43bb6d8aaa Update cutlass to 3.2.2 2023-11-19 21:43:48 -08:00
Driss Guessous
dc4b9ad6c4
add checks (#640) 2023-11-19 20:43:27 -08:00
Tri Dao
017716451d [LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton 2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton 2023-11-13 02:04:49 -08:00
Antony Frolov
3566596ad8
Fix typo in RotaryEmbedding forward output type (#666) 2023-11-09 11:43:02 -08:00
Tri Dao
83aef842be Bump to v2.3.3 2023-10-24 00:24:07 -07:00
Tri Dao
c79de85ffa [CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements 2023-10-24 00:17:34 -07:00
Tri Dao
02ac572f3f Clarify inference README is a placeholder 2023-10-12 10:14:58 -07:00
Tri Dao
7f31e7c16a Bump to v2.3.2 2023-10-08 17:21:29 -07:00
Tri Dao
5a83425442 Change constexpr int to constexpr static int 2023-10-08 16:26:33 -07:00
Tri Dao
3a9fe7b0fa Add change log 2023-10-05 14:19:08 -07:00
Tri Dao
aa4fd2d166 Clarify that Windows is not supported right now 2023-10-05 14:00:45 -07:00
Tri Dao
5e525a8dc8 [CI] Use official Pytorch 2.1, add CUDA 11.8 for Pytorch 2.1 2023-10-03 22:20:30 -07:00
Tri Dao
21c3b0d8f6 Bump to v2.3.1 2023-10-03 19:56:45 -07:00
Tri Dao
e279bf8ed9 [Gen] Accept cache_batch_idx to index into the KV cache 2023-10-03 16:27:26 -07:00
Tri Dao
601b4dc48d Bump to v2.3.0 2023-09-26 22:08:29 -07:00
Tri Dao
083e8f525f Implement local attention
Co-authored-by: Timothee Lacroix <t@mistral.ai>
2023-09-26 16:31:08 -07:00
Katherine Crowson
4c8ff9154e
Fix NameError and typo in ApplyRotaryEmbQKV_ (#569) 2023-09-25 10:47:34 -07:00
Tri Dao
0a1d03c7ea Bump to v2.2.5 2023-09-24 00:54:03 -07:00
Tri Dao
812cb1c990 Switch cutlass to newer commit to avoid compilation warning 2023-09-24 00:42:50 -07:00
Tri Dao
65c234ed90 Don't over-allocate dq_accum in case of varlen 2023-09-24 00:36:07 -07:00
Tri Dao
1879e089c7 Reduce number of templates for headdim > 128 2023-09-23 22:24:30 -07:00
Tri Dao
dd9a6fa45a Add placeholder for inference example 2023-09-22 02:31:00 -07:00
Tri Dao
bff3147175 Re-enable compilation for Hopper 2023-09-21 23:55:25 -07:00
Yuchao Dai
187c2a0635
Fix E1136 (#563) 2023-09-21 11:48:23 -07:00
Tri Dao
229080b9d2 Bump to v2.2.4 2023-09-20 23:39:38 -07:00
Tri Dao
2d8ea9a530 Swap seqlen_q and ngroups when seqlen_q=1 (h/t Daniel Haziza) 2023-09-20 23:38:22 -07:00
Tri Dao
0705d2718d [Llama] Fix some tests, add tests for Llama 2 and CodeLlama 2023-09-20 23:36:46 -07:00
Tri Dao
e0fbaa7016 [Gen] Simplify decode_speculative 2023-09-19 22:20:22 -07:00
Tri Dao
e6a8026489 [Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset 2023-09-19 22:20:22 -07:00
Kevin Hu
42832575d4
Fix Llama GQA/MQA (#546)
* Fix llama MQA

* Fix permute shape

* Update llama.py
2023-09-19 22:15:59 -07:00
Tri Dao
dfe29f5e2b [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead 2023-09-18 15:29:06 -07:00
Tri Dao
3250ff3d82 Swap seqlen_q, nheads for MQA when seqlen_q=1 for fwd (h/t Daniel H) 2023-09-18 14:52:16 -07:00
Tri Dao
43617deab9 Remove template for (IsEvenMN=T, IsEvenK=F) to speed up compilation 2023-09-18 12:21:36 -07:00
Federico Berto
fa3ddcbaaa
[Minor] add nvcc note on bare_metal_version RuntimeError (#552)
* Add nvcc note on bare_metal_version `RuntimeError`

* Run Black formatting
2023-09-18 11:48:15 -07:00
Tri Dao
799f56fa90 Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults 2023-09-17 22:15:38 -07:00
Tri Dao
c984208ddb Set block size to 64 x 64 for kvcache to avoid nvcc segfaults 2023-09-17 16:14:58 -07:00
Tri Dao
8c8b4d36e1 Bump to v2.2.3 2023-09-16 01:47:01 -07:00
Tri Dao
ccbb14f38e Implement rotary embedding in flash_attn_with_kvcache 2023-09-16 01:20:16 -07:00
Tri Dao
5400fdc4ac [CE] Implement CrossEntropyLoss in Triton 2023-09-15 20:05:28 -07:00
Tri Dao
56b7fc6ee0 Simplify the implementation of KVcache attn by appending KV first 2023-09-13 15:55:48 -07:00
Tri Dao
d0032700d1 Add tests for Pythia, GPT-JT, and RedPajama models 2023-09-13 01:10:39 -07:00
Tri Dao
bb9beb3645 Remove some unused headers 2023-09-12 12:37:10 -07:00
Tri Dao
08c295c043 Bump to v2.2.2 2023-09-10 23:48:12 -07:00
Tri Dao
ee77b931b9 Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza) 2023-09-10 22:56:33 -07:00
Kevin Hu
07005806ff
Add BigCode converters (#532) 2023-09-10 17:24:50 -07:00