Tri Dao
d79f9b41a8
[CrossEntropy] Use online softmax to simplify implementation
2024-08-24 17:40:39 -07:00
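The online-softmax trick this commit applies can be sketched in plain Python (a reference sketch only, not the Triton kernel; `online_logsumexp` and `cross_entropy` are illustrative names): cross-entropy needs the log-sum-exp of the logits, and the online formulation computes it in a single pass by rescaling the running sum whenever the running max grows.

```python
import math

def online_logsumexp(logits):
    # Single-pass log-sum-exp: keep a running max m and a running sum s
    # of exp(x - m); when a new max appears, rescale s by exp(old_m - new_m).
    m, s = float("-inf"), 0.0
    for x in logits:
        if x > m:
            s = s * math.exp(m - x) + 1.0
            m = x
        else:
            s += math.exp(x - m)
    return m + math.log(s)

def cross_entropy(logits, target):
    # CE(logits, target) = logsumexp(logits) - logits[target]
    return online_logsumexp(logits) - logits[target]
```

A two-pass implementation first finds max(logits), then sums exponentials; the online form fuses both loops into one, which is the simplification the commit refers to.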
Tri Dao
bcd918f275
[LayerNorm] Add option to write result to out and residual_out
2024-08-15 14:43:47 -07:00
Tri Dao
bd82d6c6eb
Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
2024-08-15 12:02:39 -07:00
Tri Dao
800401847e
[LayerNorm] Don't store x + residual if we don't need gradients
2024-08-15 11:08:46 -07:00
lancerts
22339db185
remove an unused import (#960)
2024-05-23 11:12:31 -07:00
Tri Dao
ec6d22143b
[CrossEntropy] Change ignored_index -> ignore_index
2024-04-26 10:50:41 -07:00
Ivan Komarov
f692b98d80
Fix spurious re-compilations of rotary_kernel (#911)
All integer parameters are specialized by default, so the two parameters
removed in this commit could lead to kernel re-compilation, even if
they were completely unused.
2024-04-05 13:40:41 -07:00
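The commit body above explains the bug class: the JIT specializes integer arguments by value, so a varying integer triggers re-compilation even if the kernel body never reads it. A hypothetical sketch of this caching behavior (not Triton's actual internals; all names here are illustrative):

```python
# A Triton-style JIT caches compiled kernels keyed by the specialization of
# each integer argument (e.g. divisible-by-16, equal-to-1, or the literal
# value), so an unused but varying integer defeats the cache.

class JitKernel:
    def __init__(self, fn):
        self.fn = fn
        self.cache = {}        # specialization key -> "compiled" kernel
        self.compile_count = 0

    def _spec(self, a):
        if isinstance(a, int):
            if a % 16 == 0:
                return "div16"
            if a == 1:
                return "eq1"
            return a           # otherwise specialized on its exact value
        return type(a)

    def __call__(self, *args):
        key = tuple(self._spec(a) for a in args)
        if key not in self.cache:
            self.compile_count += 1    # a real JIT would recompile here
            self.cache[key] = self.fn
        return self.cache[key](*args)

def rotary_body(x, seqlen):    # `seqlen` is never used in the body
    return [v * 2 for v in x]

kernel = JitKernel(rotary_body)
for seqlen in (3, 5, 7, 9):    # varying, unused integer argument
    kernel([1.0], seqlen)      # four cache misses -> four "compilations"

lean_kernel = JitKernel(lambda x: [v * 2 for v in x])
for _ in range(4):
    lean_kernel([1.0])         # argument removed -> one "compilation"
```

Dropping the unused integer parameters, as this commit does, collapses the cache keys to a single entry.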
Tri Dao
36587c01cb
[LayerNorm] Update layer_norm_linear
2024-03-18 23:15:33 -07:00
Tri Dao
bdcae547c7
[LayerNorm] Don't exit early in the backward pass (fix #781)
2024-01-22 22:40:06 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c
return z_loss (#768)
2024-01-21 15:23:41 -08:00
Tri Dao
c9861a032d
[LayerNorm] Initialize mean and rstd tensor using x.device
2024-01-09 16:30:31 -08:00
Tri Dao
f5b308e258
[LayerNorm] Rename layernorm.py -> layer_norm.py
2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2
[LayerNorm] Implement parallel layer norm in Triton
2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5
[LayerNorm] Implement rowscale in Triton layernorm
2024-01-04 01:07:03 -08:00
Tri Dao
cd089597fd
[LayerNorm] Implement dropout in fused residual + LN/RMSNorm
2023-12-19 16:26:07 -08:00
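The semantics of the fused op can be written out in plain Python (a per-row reference sketch under assumed semantics, not the Triton kernel; names are illustrative): dropout is applied to x, the residual is added, the pre-norm sum is kept for the next layer's residual stream, and LayerNorm is applied to it.

```python
import math
import random

def fused_residual_dropout_layernorm(x, residual, weight, bias, p, eps=1e-5):
    # Per-row reference: y = LayerNorm(dropout(x) + residual); the pre-norm
    # sum z is also returned so the next block can use it as its residual.
    scale = 1.0 / (1.0 - p)                    # inverted-dropout rescaling
    dropped = [v * scale if random.random() >= p else 0.0 for v in x]
    z = [d + r for d, r in zip(dropped, residual)]
    mean = sum(z) / len(z)
    var = sum((v - mean) ** 2 for v in z) / len(z)
    rstd = 1.0 / math.sqrt(var + eps)
    y = [(v - mean) * rstd * w + b for v, w, b in zip(z, weight, bias)]
    return y, z
```

Fusing these steps into one kernel saves the extra global-memory round trips that separate dropout, add, and norm launches would incur.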
Tri Dao
08124c8f9c
[CrossEntropy] Implement logit_scale option
2023-12-16 18:39:37 -08:00
Tri Dao
9356a1c038
[LayerNorm] Implement layer_norm_linear
2023-11-30 21:46:07 -08:00
Tri Dao
aaa1474129
[CrossEntropy] Simplify the case of large vocab with Tensor Parallel
2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1
fix flash ce mp large vocab (#673)
2023-11-19 23:01:07 -08:00
Tri Dao
017716451d
[LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton
2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d
[LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton
2023-11-13 02:04:49 -08:00
Tri Dao
c79de85ffa
[CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements
2023-10-24 00:17:34 -07:00
Tri Dao
5400fdc4ac
[CE] Implement CrossEntropyLoss in Triton
2023-09-15 20:05:28 -07:00
Tri Dao
8a733cbd53
[Gen] Fix calling update_graph_cache in tests
2023-09-10 17:22:37 -07:00
Tri Dao
9795159082
[Rotary] Set device before launching Triton kernel to avoid error
2023-09-05 21:29:03 -07:00
Kyeongpil Kang
8e893f0950
Create __init__.py for ops/triton dir (#516)
2023-09-05 11:29:03 -07:00
Tri Dao
b28ec236df
[Rotary] Implement varlen rotary
2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d
[Rotary] Clean up rotary Triton implementation a bit
2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1
[Rotary] Speed up rotary kernel when interleaved=True
2023-09-03 16:24:37 -07:00
Tri Dao
942fcbf046
[Rotary] Implement rotary in Triton
2023-09-03 02:51:58 -07:00
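The operation the kernel implements is simple to state; a plain-Python reference (illustrative names, non-interleaved pair layout assumed, not the Triton kernel) rotates each feature pair by a position-dependent angle:

```python
import math

def apply_rotary(x_pairs, pos, theta=10000.0):
    # Rotary embedding on one token: the pair (x1, x2) at pair index i of a
    # head of dimension d is rotated by angle pos * theta ** (-2 * i / d).
    d = 2 * len(x_pairs)
    out = []
    for i, (x1, x2) in enumerate(x_pairs):
        ang = pos * theta ** (-2.0 * i / d)
        c, s = math.cos(ang), math.sin(ang)
        out.append((x1 * c - x2 * s, x1 * s + x2 * c))
    return out
```

Since each output pair depends only on its own input pair and the position, the op parallelizes cleanly across tokens and pairs, which is what makes it a good fit for a Triton kernel.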
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Tri Dao
96d10f6545
Implement LLaMa
2023-04-18 21:51:35 -07:00
Tri Dao
6f6e9a9aaf
[FusedDense] Enable sqrelu activation in FusedMLP
2023-04-13 15:29:32 -07:00
Tri Dao
2e33fc8e36
Add GPT and ViT models
2022-11-13 22:30:23 -08:00