Commit Graph

27 Commits

Author SHA1 Message Date
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reuseable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
cb0daccc41 [FusedDense] Allow Row/ColumnParallelLinear to have uneven split 2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751 [FusedDense] Run black on fused_dense.py 2023-08-16 23:41:36 -07:00
Tri Dao
d2f4324f4c [LayerNorm] Make sure memory addresses are aligned to 16 bytes 2023-07-04 14:53:12 -07:00
Tri Dao
96d10f6545 Implement LLaMa 2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f Implement GatedMlp 2023-04-18 03:37:14 -07:00
Tri Dao
6f6e9a9aaf [FusedDense] Enable sqrelu activation in FusedMLP 2023-04-13 15:29:32 -07:00
Tri Dao
393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 2023-03-31 14:23:45 -07:00
Tri Dao
dc08ea1c33 Support H100 for other CUDA extensions 2023-03-15 16:59:27 -07:00
Tri Dao
eb33e587e9 [LayerNorm] Rename x1 -> residual 2023-01-19 13:07:27 -08:00
Tri Dao
88173a1aaf [FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP 2023-01-17 18:12:27 -08:00
Tri Dao
93383bd55b [TP] Implement TensorParallel without sequence parallel 2023-01-07 13:45:22 -08:00
Tri Dao
6738d9477d [LayerNorm] Implement RMS Norm 2023-01-06 17:34:22 -08:00
Tri Dao
1ec09ebd90 [FusedDense] Limit matrix dims to 2M (instead of 64k) 2023-01-01 17:06:39 -08:00
Tri Dao
65b4064b2a [FusedDense] Kick off input all_gather before weight dtype conversion 2022-12-31 22:47:34 -08:00
Tri Dao
a8cfe51551 Implement Tensor Parallel for transformer Block 2022-12-25 14:08:21 -08:00
Tri Dao
226a1b721d Implement TensorParallel for FusedDense and FusedDenseGeluDense 2022-12-24 11:48:56 -08:00
Tri Dao
e68ebbe89a Simplify FusedDense 2022-12-22 21:25:31 -08:00
Tri Dao
5fb6df0e04 Implement BERT 2022-12-18 21:47:27 -08:00
Tri Dao
5db330519a [LayerNorm] Support taking subset of input or subset of output 2022-12-12 22:16:14 -08:00
Tri Dao
ae137ed17a [LayerNorm] Fuse LayerScale 2022-12-10 23:28:23 -08:00
Tri Dao
8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2022-12-09 02:06:22 -08:00
Tri Dao
ece539abd6 Add __init__.py files to subdirectories for installation 2022-11-17 16:55:44 -08:00
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f Add MLP, MHA, Block, Embedding modules 2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2022-11-13 21:59:20 -08:00