Tri Dao
8c20cfef49
[Rotary] Support qkv block layout from GQA
2024-09-11 10:39:58 -07:00
Tri Dao
c7f32a8409
[CrossEntropy] Support precomputed LSE
2024-09-08 09:24:43 -07:00
Tri Dao
d79f9b41a8
[CrossEntropy] Use online softmax to simplify implementation
2024-08-24 17:40:39 -07:00
Tri Dao
bcd918f275
[LayerNorm] Add option to write result to out and residual_out
2024-08-15 14:43:47 -07:00
Tri Dao
bd82d6c6eb
Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
2024-08-15 12:02:39 -07:00
Tri Dao
800401847e
[LayerNorm] Don't store x + residual if we don't need gradients
2024-08-15 11:08:46 -07:00
lancerts
22339db185
remove an unused import (#960)
2024-05-23 11:12:31 -07:00
Tri Dao
ec6d22143b
[CrossEntropy] Change ignored_index -> ignore_index
2024-04-26 10:50:41 -07:00
Ivan Komarov
f692b98d80
Fix spurious re-compilations of rotary_kernel (#911)
All integer parameters are specialized by default, so the two parameters
removed in this commit could lead to kernel re-compilation, even if
they were completely unused.
2024-04-05 13:40:41 -07:00
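The note above describes Triton's default integer specialization. Below is a minimal sketch (a toy copy kernel with a hypothetical `unused_int` parameter, not the actual rotary_kernel) of how an unused integer argument can still trigger re-compilation when its value changes, which is why removing such parameters helps.

```python
# Minimal sketch: Triton includes integer arguments in the specialization key by
# default (e.g. whether the value is 1 or divisible by 16), so passing different
# values for `unused_int` can force a fresh compile even though the kernel never
# reads it. Dropping the argument (as the commit does) avoids this; Triton's
# do_not_specialize option is another route.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(
    out_ptr, x_ptr, n_elements,
    unused_int,  # hypothetical unused integer argument
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
for v in (1, 7, 16):  # different values -> potentially different specializations
    copy_kernel[grid](out, x, x.numel(), v, BLOCK=128)
```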
Tri Dao
36587c01cb
[LayerNorm] Update layer_norm_linear
2024-03-18 23:15:33 -07:00
Tri Dao
bdcae547c7
[LayerNorm] Don't exit early in the backward pass (fix #781)
2024-01-22 22:40:06 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c
return z_loss (#768)
2024-01-21 15:23:41 -08:00
Tri Dao
c9861a032d
[LayerNorm] Initialize mean and rstd tensor using x.device
2024-01-09 16:30:31 -08:00
Tri Dao
f5b308e258
[LayerNorm] Rename layernorm.py -> layer_norm.py
2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2
[LayerNorm] Implement parallel layer norm in Triton
2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5
[LayerNorm] Implement rowscale in Triton layernorm
2024-01-04 01:07:03 -08:00
Tri Dao
cd089597fd
[LayerNorm] Implement dropout in fused residual + LN/RMSNorm
2023-12-19 16:26:07 -08:00
Tri Dao
08124c8f9c
[CrossEntropy] Implement logit_scale option
2023-12-16 18:39:37 -08:00
Tri Dao
9356a1c038
[LayerNorm] Implement layer_norm_linear
2023-11-30 21:46:07 -08:00
Tri Dao
aaa1474129
[CrossEntropy] Simplify the case of large vocab with Tensor Parallel
2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash cross-entropy with model parallel for large vocab (#673)
2023-11-19 23:01:07 -08:00
Tri Dao
017716451d
[LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton
2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d
[LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton
2023-11-13 02:04:49 -08:00
Tri Dao
c79de85ffa
[CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements
2023-10-24 00:17:34 -07:00
Tri Dao
5400fdc4ac
[CE] Implement CrossEntropyLoss in Triton
2023-09-15 20:05:28 -07:00
Tri Dao
8a733cbd53
[Gen] Fix calling update_graph_cache in tests
2023-09-10 17:22:37 -07:00
Tri Dao
9795159082
[Rotary] Set device before launching Triton kernel to avoid error
2023-09-05 21:29:03 -07:00
Kyeongpil Kang
8e893f0950
Create __init__.py for ops/triton dir (#516)
2023-09-05 11:29:03 -07:00
Tri Dao
3557e0bb8f
[MLP] Implement SwiGLU with torch jiterator
2023-09-04 15:43:53 -07:00
Tri Dao
b28ec236df
[Rotary] Implement varlen rotary
2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d
[Rotary] Clean up rotary Triton implementation a bit
2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1
[Rotary] Speed up rotary kernel when interleaved=True
2023-09-03 16:24:37 -07:00
Tri Dao
942fcbf046
[Rotary] Implement rotary in Triton
2023-09-03 02:51:58 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not defined when using checkpoint_lvl=2 (#502)
When using checkpoint_lvl=2, all_gather_raw(x) is called without async_op=True, so there is no handle to wait for; the wait is simply skipped.
2023-08-29 23:46:10 -07:00
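The note above hinges on when all_gather_raw returns a handle worth waiting on. Below is a minimal sketch of that pattern under assumed names (this simplified all_gather_raw and gather_input are illustrations, not the library's actual code): wait on the handle only when the gather was launched asynchronously.

```python
import torch
import torch.distributed as dist


def all_gather_raw(x: torch.Tensor, group=None, async_op: bool = False):
    """Gather x along dim 0 across the process group (simplified illustration)."""
    world_size = dist.get_world_size(group)
    out = torch.empty(world_size * x.shape[0], *x.shape[1:],
                      dtype=x.dtype, device=x.device)
    handle = dist.all_gather_into_tensor(out, x.contiguous(),
                                         group=group, async_op=async_op)
    return out, handle  # handle is None when async_op=False


def gather_input(x: torch.Tensor, group=None, checkpoint_lvl: int = 0):
    # At checkpoint_lvl=2 the gather is issued synchronously (no async_op=True),
    # so there is no handle to wait on and the wait is simply skipped.
    async_op = checkpoint_lvl != 2
    total_x, handle_x = all_gather_raw(x, group, async_op=async_op)
    if handle_x is not None:
        handle_x.wait()
    return total_x
```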
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
Support num_heads not divisible by world_size; resolves #459 (#461)
* unequal rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reusable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
2023-08-18 14:10:35 -07:00
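The entry above adds support for a head count that does not divide evenly across tensor-parallel ranks. Below is a minimal sketch (heads_per_rank is a hypothetical helper, not the library's actual code) of one common way to do such a split: the first num_heads % world_size ranks each take one extra head.

```python
# Hypothetical helper: split num_heads as evenly as possible across ranks.
def heads_per_rank(num_heads: int, world_size: int) -> list[int]:
    base, extra = divmod(num_heads, world_size)
    return [base + 1 if rank < extra else base for rank in range(world_size)]


# Example: 14 heads over 4 ranks -> [4, 4, 3, 3]; per-rank projection widths
# (heads * head_dim) would then drive an uneven Row/ColumnParallelLinear split.
assert heads_per_rank(14, 4) == [4, 4, 3, 3]
assert sum(heads_per_rank(14, 4)) == 14
```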
Tri Dao
cb0daccc41
[FusedDense] Allow Row/ColumnParallelLinear to have uneven split
2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751
[FusedDense] Run black on fused_dense.py
2023-08-16 23:41:36 -07:00
Tri Dao
d2f4324f4c
[LayerNorm] Make sure memory addresses are aligned to 16 bytes
2023-07-04 14:53:12 -07:00
Tri Dao
96d10f6545
Implement LLaMa
2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f
Implement GatedMlp
2023-04-18 03:37:14 -07:00
Tri Dao
6f6e9a9aaf
[FusedDense] Enable sqrelu activation in FusedMLP
2023-04-13 15:29:32 -07:00
Tri Dao
393882bc08
[LayerNorm] Implement LN with parallel residual, support dim 8k
2023-03-31 14:23:45 -07:00
Tri Dao
dc08ea1c33
Support H100 for other CUDA extensions
2023-03-15 16:59:27 -07:00
Tri Dao
eb33e587e9
[LayerNorm] Rename x1 -> residual
2023-01-19 13:07:27 -08:00
Tri Dao
88173a1aaf
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
2023-01-17 18:12:27 -08:00
Tri Dao
93383bd55b
[TP] Implement TensorParallel without sequence parallel
2023-01-07 13:45:22 -08:00
Tri Dao
6738d9477d
[LayerNorm] Implement RMS Norm
2023-01-06 17:34:22 -08:00
Tri Dao
1ec09ebd90
[FusedDense] Limit matrix dims to 2M (instead of 64k)
2023-01-01 17:06:39 -08:00
Tri Dao
65b4064b2a
[FusedDense] Kick off input all_gather before weight dtype conversion
2022-12-31 22:47:34 -08:00