Commit Graph

176 Commits

Author              SHA1        Date                        Message
Tri Dao             63670fd84a  2022-12-27 21:01:50 -08:00  Implement generation for GPT
Tri Dao             9d797d8848  2022-12-27 11:22:48 -08:00  Support loading GPT2 weights from Huggingface
Tri Dao             c6ecd40a59  2022-12-27 10:47:43 -08:00  Tweak CrossEntropyLoss to take process_group in init
Tri Dao             b4018a5028  2022-12-26 16:22:43 -08:00  Implement Tensor Parallel for GPT model
Tri Dao             78225c5366  2022-12-25 14:29:53 -08:00  Implement Tensor Parallel for GPT2Embeddings
Tri Dao             a8cfe51551  2022-12-25 14:08:21 -08:00  Implement Tensor Parallel for transformer Block
Tri Dao             1e712ea8b0  2022-12-25 11:39:55 -08:00  Implement TensorParallel for MHA
Tri Dao             226a1b721d  2022-12-24 11:48:56 -08:00  Implement TensorParallel for FusedDense and FusedDenseGeluDense
Tri Dao             dff68c2b22  2022-12-23 14:51:08 -08:00  Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss
Tri Dao             e68ebbe89a  2022-12-22 21:25:31 -08:00  Simplify FusedDense
Tri Dao             1bc6e5b09c  2022-12-21 14:33:18 -08:00  Bump to v0.2.5
Tri Dao             496e4f528c  2022-12-21 14:17:58 -08:00  Implement XPos (Sun et al.)
Tri Dao             c2407dec96  2022-12-21 13:42:30 -08:00  Fix typo in config: train.gpu -> train.gpu_mem
Tri Dao             13cdceb377  2022-12-19 22:18:46 -08:00  Implement last_layer_subset optimization for BERT
Tri Dao             5fb6df0e04  2022-12-18 21:47:27 -08:00  Implement BERT
Tri Dao             dc24c22603  2022-12-17 11:22:06 -08:00  Merge pull request #92 from ploshkin/rm-shape-asserts (Fix slicing dimensions in rotary embeddings)
Alexander Ploshkin  ee8984d2be  2022-12-17 13:34:57 +04:00  add asserts for sin shape
Alexander Ploshkin  c7c66976cc  2022-12-16 15:39:06 +04:00  fix slicing dimensions
Tri Dao             b78f5a392d  2022-12-15 19:49:04 -08:00  [Docs] Mention Megatron-LM
Tri Dao             ece8f05d09  2022-12-15 19:44:59 -08:00  [Docs] Mention PubMedGPT
Alexander Ploshkin  96656b9323  2022-12-15 18:13:21 +04:00  Remove redundant shape asserts in rotary embeddings
Tri Dao             04c4c6106e  2022-12-14 14:49:26 -08:00  Bump to v0.2.4
Tri Dao             6b5f271c6d  2022-12-14 14:48:41 -08:00  [Triton] Avoid einops repeat by using Tensor.expand
Tri Dao             88c4e5dbf6  2022-12-13 13:58:17 -08:00  Fix the case when dout is not contiguous
Tri Dao             a1a5d2ee49  2022-12-13 01:37:02 -08:00  Bump to v0.2.3
Tri Dao             5db330519a  2022-12-12 22:16:14 -08:00  [LayerNorm] Support taking subset of input or subset of output
Tri Dao             ae137ed17a  2022-12-10 23:28:23 -08:00  [LayerNorm] Fuse LayerScale
Tri Dao             8c6609ae1a  2022-12-09 02:06:22 -08:00  [LayerNorm] Support all dimensions up to 6k (if divisible by 8)
Tri Dao             8a2ece89f7  2022-12-06 14:38:32 -08:00  Simplify BOOL_SWITCH macro to fix compiling error on gcc 7
Tri Dao             a84d07283c  2022-12-05 00:34:09 -08:00  [Docs] Mention FasterTransformer integration
Tri Dao             4a6eaa9f27  2022-11-29 04:46:43 -08:00  Update configs, add results
Tri Dao             0bf5e50038  2022-11-28 17:34:40 -08:00  Release training code
Tri Dao             9bc63d1e2d  2022-11-25 16:35:08 -08:00  Fix typo in comments
Tri Dao             d95ee1a95d  2022-11-25 16:30:18 -08:00  Speed up compilation by splitting into separate .cu files
Tri Dao             b784ed73cf  2022-11-25 10:49:17 -08:00  [Docs] Clarify OpenFold speedup
Tri Dao             d9021ae4ec  2022-11-23 13:01:19 -08:00  [Docs] Mention OpenFold
Tri Dao             1feb94265c  2022-11-23 12:48:56 -08:00  [ViT] Use dropout_add_ln for the 1st layer norm
Tri Dao             45bcf37b97  2022-11-22 02:12:22 -08:00  [Docs] Capitalize the bibtex citation
Tri Dao             b8ccd20098  2022-11-22 02:07:32 -08:00  [Triton] Fix variable name from qkv to kv (h/t FrankZijlstra)
Tri Dao             054816177e  2022-11-20 22:35:59 -08:00  Bump version to 0.2.1
Tri Dao             0fa5c0d7ef  2022-11-17 16:56:06 -08:00  Add PatchEmbed
Tri Dao             ece539abd6  2022-11-17 16:55:44 -08:00  Add __init__.py files to subdirectories for installation
Tri Dao             39ed597b28  2022-11-17 11:45:11 -08:00  [LayerNorm] Compile for both sm70 and sm80
Tri Dao             71f674ae23  2022-11-17 11:43:36 -08:00  [Rotary] Customize base, support seqlen_offset
Tri Dao             d6ef701aa9  2022-11-15 14:15:05 -08:00  Set version to 0.2.0 (instead of 0.2)
Tri Dao             4040256b5e  2022-11-15 14:10:48 -08:00  Update pip install instructions, bump to 0.2
Tri Dao             56aa49037d  2022-11-15 13:55:59 -08:00  Merge pull request #75 from lucidrains/main (allow for uploading to pypi)
Phil Wang           dcf3986590  2022-11-15 13:38:13 -08:00  update manifest
Phil Wang           b0eac3297f  2022-11-15 13:26:55 -08:00  allow for uploading to pypi
Tri Dao             43ab0b5205  2022-11-15 07:10:25 -08:00  Mention that some CUDA extensions have only been tested on A100s