Commit Graph

35 Commits

Author SHA1 Message Date
Tri Dao
c7f32a8409 [CrossEntropy] Support precomputed LSE 2024-09-08 09:24:43 -07:00
Tri Dao
d79f9b41a8 [CrossEntropy] Use online softmax to simplify implementation 2024-08-24 17:40:39 -07:00
Tri Dao
bcd918f275 [LayerNorm] Add option to write result to out and residual_out 2024-08-15 14:43:47 -07:00
Tri Dao
bd82d6c6eb Revert "[LayerNorm] Don't store x + residual if we don't need gradients"
This reverts commit 800401847e.
2024-08-15 12:02:39 -07:00
Tri Dao
800401847e [LayerNorm] Don't store x + residual if we don't need gradients 2024-08-15 11:08:46 -07:00
lancerts
22339db185 remove an unused import (#960) 2024-05-23 11:12:31 -07:00
Tri Dao
ec6d22143b [CrossEntropy] Change ignored_index -> ignore_index 2024-04-26 10:50:41 -07:00
Ivan Komarov
f692b98d80 Fix spurious re-compilations of rotary_kernel (#911)
All integer parameters are specialized by default, so the two parameters
removed in this commit could lead to kernel re-compilation, even if
they were completely unused.
2024-04-05 13:40:41 -07:00
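
The body of f692b98d80 points at a Triton JIT behavior worth illustrating: integer kernel arguments are specialized by default (their value properties become part of the compilation cache key), so an argument the kernel never reads can still force a re-compile. Below is a minimal, hypothetical sketch of that pitfall, not the rotary kernel from the commit; the kernel, its arguments, and the `copy()` wrapper are invented for illustration.

```python
# Minimal sketch (not flash-attn code): Triton's JIT specializes integer
# arguments (e.g. on value 1 or divisibility by 16), and those properties go
# into the kernel cache key, so an unread argument can still trigger recompiles.
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(
    out_ptr,
    in_ptr,
    n_elements,   # used for the bounds check
    unused_int,   # never read, but still specialized -> may cause recompilation
    BLOCK: tl.constexpr,
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, tl.load(in_ptr + offs, mask=mask), mask=mask)


def copy(x: torch.Tensor, unused_int: int) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Calling with unused_int=5 and then unused_int=16 can compile two variants,
    # since 16 is divisible by 16 and 5 is not, even though the value is unused.
    copy_kernel[grid](out, x, n, unused_int, BLOCK=1024)
    return out
```

Dropping the unused arguments, as the commit does, removes the extra compiles; Triton's `do_not_specialize` option on `triton.jit` is the alternative when an argument has to stay in the signature.
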
Tri Dao
36587c01cb [LayerNorm] Update layer_norm_linear 2024-03-18 23:15:33 -07:00
Tri Dao
bdcae547c7 [LayerNorm] Don't exit early in the backward pass (fix #781) 2024-01-22 22:40:06 -08:00
Curtis "Fjord" Hawthorne
d8aacc510c return z_loss (#768) 2024-01-21 15:23:41 -08:00
Tri Dao
c9861a032d [LayerNorm] Initialize mean and rstd tensor using x.device 2024-01-09 16:30:31 -08:00
Tri Dao
f5b308e258 [LayerNorm] Rename layernorm.py -> layer_norm.py 2024-01-05 00:21:03 -08:00
Tri Dao
665b55e2e2 [LayerNorm] Implement parallel layer norm in Triton 2024-01-04 23:15:35 -08:00
Tri Dao
aa5c6438c5 [LayerNorm] Implement rowscale in Triton layernorm 2024-01-04 01:07:03 -08:00
Tri Dao
cd089597fd [LayerNorm] Implement dropout in fused residual + LN/RMSNorm 2023-12-19 16:26:07 -08:00
Tri Dao
08124c8f9c [CrossEntropy] Implement logit_scale option 2023-12-16 18:39:37 -08:00
Tri Dao
9356a1c038 [LayerNorm] Implement layer_norm_linear 2023-11-30 21:46:07 -08:00
Tri Dao
aaa1474129 [CrossEntropy] Simplify the case of large vocab with Tensor Parallel 2023-11-19 23:19:36 -08:00
Shijie
abf04a56e1 fix flash ce mp large vocab (#673) 2023-11-19 23:01:07 -08:00
Tri Dao
017716451d [LayerNorm] Add postnorm residual + LayerNorm/RMSNorm in Triton 2023-11-13 22:37:55 -08:00
Tri Dao
79bd1a2d5d [LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton 2023-11-13 02:04:49 -08:00
Tri Dao
c79de85ffa [CrossEntropy] Fix triton cross_entropy_loss IMA for >=2B elements 2023-10-24 00:17:34 -07:00
Tri Dao
5400fdc4ac [CE] Implement CrossEntropyLoss in Triton 2023-09-15 20:05:28 -07:00
Tri Dao
8a733cbd53 [Gen] Fix calling update_graph_cache in tests 2023-09-10 17:22:37 -07:00
Tri Dao
9795159082 [Rotary] Set device before launching Triton kernel to avoid error 2023-09-05 21:29:03 -07:00
Kyeongpil Kang
8e893f0950 Create __init__.py for ops/triton dir (#516) 2023-09-05 11:29:03 -07:00
Tri Dao
b28ec236df [Rotary] Implement varlen rotary 2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d [Rotary] Clean up rotary Triton implementation a bit 2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1 [Rotary] Speed up rotary kernel when interleaved=True 2023-09-03 16:24:37 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
6f6e9a9aaf [FusedDense] Enable sqrelu activation in FusedMLP 2023-04-13 15:29:32 -07:00
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00