Author | Commit | Message | Date
Tri Dao | 62e9814466 | [Rotary] Make sure frequency calculation is in fp32 | 2023-07-02 16:39:39 -07:00
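A minimal sketch of the idea behind this fix, assuming the standard rotary-embedding setup (the function name is illustrative, not the library's API): the inverse frequencies are computed in fp32 even when the model runs in fp16/bf16, since low-precision exponentiation perturbs the rotation angles.

```python
import torch

def rotary_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Compute inverse frequencies in fp32 regardless of the model dtype;
    # doing this in fp16/bf16 rounds the exponents and shifts the angles.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
```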
Tri Dao | 48bc6eacd6 | [Gen] Add rotary base as an argument to FT attention kernel | 2023-05-30 13:38:34 -07:00
Tri Dao | a9a4b4e4f2 | [LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm | 2023-05-04 23:39:43 -07:00
Tri Dao | 311d6606bf | [Gen] Fix FT kernel smem size, CG when batch size changed | 2023-04-20 17:03:13 -07:00
Tri Dao | 96d10f6545 | Implement LLaMa | 2023-04-18 21:51:35 -07:00
Tri Dao | 605655bc66 | [Gen] Fix FT kernel when using CG | 2023-04-14 16:50:01 -07:00
Tri Dao | 393882bc08 | [LayerNorm] Implement LN with parallel residual, support dim 8k | 2023-03-31 14:23:45 -07:00
Tri Dao | 993d12448e | Implement GPT-NeoX | 2023-03-29 01:21:25 -07:00
Tri Dao | f5d0fbd468 | [FT] Fix FT's single query attention for bf16 hdim128 rotary | 2023-03-28 21:27:00 -07:00
Tri Dao | 4d87e4d875 | Implement GPT-J | 2023-03-22 16:16:58 -07:00
Tri Dao | e45a46a5b7 | [Rotary] Implement GPT-J style (interleaved) rotary | 2023-03-14 14:35:53 -07:00
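A minimal sketch of the difference, assuming `cos` and `sin` are per-position tables broadcastable against the paired head dimension (names are illustrative): GPT-J style rotary rotates adjacent feature pairs (x0, x1), (x2, x3), ..., whereas the GPT-NeoX style rotates the first and second halves of the head dimension.

```python
import torch

def apply_rotary_interleaved(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # GPT-J (interleaved) style: each even/odd feature pair forms one 2-D rotation.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    # Re-interleave so the output layout matches the input layout.
    return torch.stack((out1, out2), dim=-1).flatten(-2)
```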
Tri Dao | 78b7a1dc18 | [OPT] Load fp16 weights on CPU before moving to GPU | 2023-01-22 17:01:32 -08:00
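A minimal sketch of the loading pattern, with a hypothetical checkpoint path and an already-constructed `model`: the checkpoint is materialized on CPU first and only the finished model is moved to the GPU, so the checkpoint tensors and the model never occupy GPU memory at the same time.

```python
import torch

# Load the fp16 checkpoint into CPU memory first (hypothetical path).
state_dict = torch.load("opt-1.3b-fp16.pt", map_location="cpu")
model.load_state_dict(state_dict)
# Only after the weights are assigned is the model moved to the GPU.
model = model.to(device="cuda", dtype=torch.float16)
```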
Tri Dao | f68d41ec77 | [Gen] Add OPT to generation test | 2023-01-17 19:59:06 -08:00
Tri Dao | 88173a1aaf | [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP | 2023-01-17 18:12:27 -08:00
Tri Dao | 780e8eeabb | [ViT] Support timm checkpoint, add tests | 2023-01-16 01:20:34 -08:00
Tri Dao | ff34123bd4 | Reorder LN in Block, support OPT | 2023-01-15 22:14:31 -08:00
Tri Dao | f1e01c27ba | [Gen] Pass qkv_stride to ft_attention kernel for batched generation | 2023-01-15 15:20:01 -08:00
Tri Dao | 7c2191542a | [Gen] Make generation work with Tensor Parallel | 2023-01-15 11:34:27 -08:00
Tri Dao | b48599002a | [Gen] Add timing option | 2023-01-07 19:05:09 -08:00
Tri Dao | 0938298e4c | [Gen] Adjust shape of kv_cache when using FT | 2023-01-07 17:27:54 -08:00
Tri Dao | e02fd588aa | [Gen] Implement top-k and top-p sampling | 2023-01-07 17:00:02 -08:00
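A minimal sketch of top-k and top-p (nucleus) filtering over logits of shape (batch, vocab); the function name and exact filtering order are illustrative rather than the library's implementation.

```python
import torch

def sample(logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0,
           temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row from logits of shape (batch, vocab)."""
    logits = logits / temperature
    if top_k > 0:
        # Keep only the top_k largest logits; mask the rest to -inf.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    if top_p < 1.0:
        # Nucleus sampling: drop a token once the cumulative probability of the
        # higher-ranked tokens already exceeds top_p (the top token is always kept).
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        probs_sorted = torch.softmax(sorted_logits, dim=-1)
        cum_before = probs_sorted.cumsum(dim=-1) - probs_sorted
        mask_sorted = cum_before > top_p
        mask = torch.zeros_like(mask_sorted).scatter(-1, sorted_idx, mask_sorted)
        logits = logits.masked_fill(mask, float("-inf"))
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
```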
Tri Dao | 11be742aa3 | [Gen] Test generation with rotary embedding | 2023-01-07 14:37:54 -08:00
Tri Dao | 93383bd55b | [TP] Implement TensorParallel without sequence parallel | 2023-01-07 13:45:22 -08:00
Tri Dao | 6738d9477d | [LayerNorm] Implement RMS Norm | 2023-01-06 17:34:22 -08:00
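For reference, a minimal unfused sketch of RMSNorm (the repo's version is a fused CUDA kernel): the input is scaled by the reciprocal root-mean-square of its features and a learned per-feature weight, with no mean subtraction or bias as in LayerNorm.

```python
import torch

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the feature dimension.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight
```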
Tri Dao | a668890fcd | [Gen] Add option to run generation with FT attention kernel | 2023-01-03 22:10:31 -08:00
Tri Dao | ef1ba918c6 | [GPT] Refactor function to shard state_dict for TensorParallel | 2023-01-01 00:09:33 -08:00
Tri Dao | 63670fd84a | Implement generation for GPT | 2022-12-27 21:01:50 -08:00
Tri Dao | 9d797d8848 | Support loading GPT2 weights from Huggingface | 2022-12-27 11:22:48 -08:00
Tri Dao | c6ecd40a59 | Tweak CrossEntropyLoss to take process_group in init | 2022-12-27 10:47:43 -08:00
Tri Dao | b4018a5028 | Implement Tensor Parallel for GPT model | 2022-12-26 16:22:43 -08:00
Tri Dao | 78225c5366 | Implement Tensor Parallel for GPT2Embeddings | 2022-12-25 14:29:53 -08:00
Tri Dao | a8cfe51551 | Implement Tensor Parallel for transformer Block | 2022-12-25 14:08:21 -08:00
Tri Dao | 1e712ea8b0 | Implement TensorParallel for MHA | 2022-12-25 11:39:55 -08:00
Tri Dao | 226a1b721d | Implement TensorParallel for FusedDense and FusedDenseGeluDense | 2022-12-24 11:48:56 -08:00
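A minimal sketch of the usual column-/row-parallel split for an MLP under tensor parallelism, assuming each rank already holds its weight shards (function and parameter names are illustrative, not the library's API).

```python
import torch
import torch.distributed as dist

def tensor_parallel_mlp(x: torch.Tensor, w1_shard: torch.Tensor,
                        w2_shard: torch.Tensor, process_group) -> torch.Tensor:
    # fc1 is column-parallel: each rank owns a slice of the hidden features,
    # so the activation can be applied locally with no communication.
    hidden = torch.nn.functional.gelu(x @ w1_shard.t())
    # fc2 is row-parallel: each rank produces a partial sum over its slice of
    # the hidden dimension, combined with a single all-reduce.
    out = hidden @ w2_shard.t()
    dist.all_reduce(out, op=dist.ReduceOp.SUM, group=process_group)
    return out
```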
Tri Dao | dff68c2b22 | Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss | 2022-12-23 14:51:08 -08:00
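A minimal, unfused single-GPU sketch of what label smoothing adds to the loss (names are illustrative): the usual negative log-likelihood of the target class is mixed with the mean negative log-probability over all classes, weighted by the smoothing factor.

```python
import torch
import torch.nn.functional as F

def smoothed_cross_entropy(logits: torch.Tensor, target: torch.Tensor,
                           smoothing: float = 0.1) -> torch.Tensor:
    # logits: (batch, vocab); target: (batch,) of class indices.
    logprobs = F.log_softmax(logits, dim=-1)
    nll = -logprobs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    smooth = -logprobs.mean(dim=-1)
    return ((1.0 - smoothing) * nll + smoothing * smooth).mean()
```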
Tri Dao | e68ebbe89a | Simplify FusedDense | 2022-12-22 21:25:31 -08:00
Tri Dao | 13cdceb377 | Implement last_layer_subset optimization for BERT | 2022-12-19 22:18:46 -08:00
Tri Dao | 5fb6df0e04 | Implement BERT | 2022-12-18 21:47:27 -08:00
Tri Dao | 5db330519a | [LayerNorm] Support taking subset of input or subset of output | 2022-12-12 22:16:14 -08:00
Tri Dao | ae137ed17a | [LayerNorm] Fuse LayerScale | 2022-12-10 23:28:23 -08:00
Tri Dao | 8c6609ae1a | [LayerNorm] Support all dimensions up to 6k (if divisible by 8) | 2022-12-09 02:06:22 -08:00
Tri Dao | d4b320b31f | Add MLP, MHA, Block, Embedding modules | 2022-11-13 22:06:44 -08:00
Tri Dao | fa6d1ce44f | Add fused_dense and dropout_add_layernorm CUDA extensions | 2022-11-13 21:59:20 -08:00
Tri Dao | 343492ec30 | Make nccl operations async in CrossEntropyLossParallel | 2022-11-13 17:27:26 -08:00
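A minimal sketch of the async pattern with torch.distributed (tensor and group names are illustrative): the collective is launched with async_op=True so independent work can overlap with communication, and the handle is waited on only when the result is needed.

```python
import torch
import torch.distributed as dist

def async_sum(partial: torch.Tensor, process_group) -> torch.Tensor:
    # Launch the NCCL all-reduce asynchronously and keep the work handle.
    handle = dist.all_reduce(partial, op=dist.ReduceOp.SUM,
                             group=process_group, async_op=True)
    # ... computation independent of `partial` could overlap with the comm here ...
    handle.wait()  # block only once the reduced value is actually needed
    return partial
```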
Tri Dao | a8fec99a9a | Skip flash_attn_split test | 2022-11-13 12:27:48 -08:00
Tri Dao | 9d3116addf | Don't enforce bitwise consistency for dq in race condition test (since we could be parallelizing over seqlen_k) | 2022-11-13 12:21:51 -08:00
Tri Dao | 7c9953815a | Add fused cross entropy loss | 2022-11-12 21:58:41 -08:00
Tri Dao | 6998e0ecdb | Fix out-of-bound memory read | 2022-11-09 09:34:14 -08:00
Tri Dao | 7479757191 | Fix pipelining bug in Triton bwd with bias_type=matrix | 2022-11-06 11:50:35 -08:00
Tri Dao | 557781933d | Parallelize CUDA bwd along seqlen_k instead of seqlen_q (faster, since only dq needs atomic adds rather than both dk and dv) | 2022-11-05 16:26:17 -07:00
Tri Dao | ca81f32e04 | Implement rotary embedding in CUDA | 2022-11-04 22:42:01 -07:00
Tri Dao | ff78ea4123 | Fix race condition in Triton bwd when there's bias | 2022-11-04 11:20:27 -07:00
Tri Dao | 86862cfd7b | Implement attention bias for Triton version | 2022-11-04 10:33:54 -07:00
Tri Dao | aacc10fbab | Fix race condition in Triton bwd for non-po2 headdims | 2022-11-02 07:32:54 -07:00
Tri Dao | 1fb12afdfb | Avoid memcpy in the Triton bwd | 2022-11-01 15:06:45 -07:00
Tri Dao | 9b0bc97872 | Fix race condition in Triton fwd | 2022-10-31 14:34:57 -07:00
Tri Dao | 4f81aff46e | Add debug_barrier for all headdims in Triton bwd | 2022-10-31 01:25:02 -07:00
Tri Dao | e78d509c64 | [WIP] Support all head dimensions up to 128 in the Triton bwd (WIP because there seem to be race conditions for head dimensions other than 16, 32, 64, 128) | 2022-10-31 00:46:22 -07:00
Tri Dao | 008951f1d9 | Support all head dimensions up to 128 in the Triton fwd | 2022-10-30 22:10:48 -07:00
Tri Dao | b910bf14c1 | Support arbitrary seqlens (both q & k) in Triton bwd | 2022-10-30 21:50:53 -07:00
Tri Dao | dc55469355 | Support arbitrary seqlen_k in Triton bwd | 2022-10-30 21:26:26 -07:00
Tri Dao | d11341fd1a | Fix Triton fwd to support seqlen not multiples of 128 | 2022-10-30 19:05:47 -07:00
Tri Dao | b0c0db81f6 | Implement FlashAttention in Triton | 2022-10-30 18:09:11 -07:00
Tri Dao | 46fd2a20b2 | Support all head dims that are multiples of 8, up to 128 | 2022-10-24 16:04:21 -07:00
Tri Dao | a5a8806d1a | Split bwd on the seqlen_q dimension | 2022-10-23 11:35:15 -07:00
Tri Dao | 1aa6d7d9b6 | Rework dropout to decouple forward and backward (they don't have to share the same block size, number of threads, etc.) | 2022-10-21 12:04:27 -07:00
Tri Dao | 52fb4b729b | Fix #54: set device for multi-GPU case | 2022-10-16 12:51:26 -07:00
Tri Dao | 5badfb7848 | Implement attention kernel that splits the batch into two | 2022-10-13 20:49:02 -07:00
Tri Dao | 0c01568daf | Only run backward test for d=128 on A100 | 2022-10-04 18:06:08 -07:00
Tri Dao | 2ed471ecc4 | Add tests for numerical error | 2022-07-22 17:54:09 -04:00