Commit Graph

553 Commits

Author SHA1 Message Date
Tri Dao
88c4e5dbf6 Fix the case when dout is not contiguous 2022-12-13 13:58:17 -08:00
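A minimal sketch of the usual guard behind this kind of fix; the helper name is illustrative, not the repo's actual code:

```python
import torch

def ensure_contiguous(dout: torch.Tensor) -> torch.Tensor:
    # CUDA kernels typically assume a contiguous layout; the gradient handed
    # to backward (e.g. after a transpose upstream) may not satisfy that.
    return dout if dout.is_contiguous() else dout.contiguous()

# A transposed tensor is non-contiguous until it is copied.
dout = torch.randn(4, 8).t()
assert not dout.is_contiguous()
assert ensure_contiguous(dout).is_contiguous()
```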
Tri Dao
a1a5d2ee49 Bump to v0.2.3 2022-12-13 01:37:02 -08:00
Tri Dao
5db330519a [LayerNorm] Support taking subset of input or subset of output 2022-12-12 22:16:14 -08:00
Tri Dao
ae137ed17a [LayerNorm] Fuse LayerScale 2022-12-10 23:28:23 -08:00
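LayerScale (from CaiT) multiplies activations by a learnable per-channel vector initialized to a small value. Below is a rough, unfused PyTorch reference of the pattern such a fused kernel absorbs; the module and the composition order are illustrative, and the extension's actual API and ordering may differ:

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Learnable per-channel scaling, initialized small (CaiT-style)."""
    def __init__(self, dim: int, init_value: float = 1e-4):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gamma * x

# Unfused reference: LayerNorm then LayerScale as two separate passes over the
# data; a fused kernel performs the scaling inside the same pass.
dim = 64
x = torch.randn(2, 16, dim)
out = LayerScale(dim)(nn.LayerNorm(dim)(x))
```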
Tri Dao
8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2022-12-09 02:06:22 -08:00
Tri Dao
8a2ece89f7 Simplify BOOL_SWITCH macro to fix compiling error on gcc 7 2022-12-06 14:38:32 -08:00
Tri Dao
a84d07283c [Docs] Mention FasterTransformer integration 2022-12-05 00:34:09 -08:00
Tri Dao
4a6eaa9f27 Update configs, add results 2022-11-29 04:46:43 -08:00
Tri Dao
0bf5e50038 Release training code 2022-11-28 17:34:40 -08:00
Tri Dao
9bc63d1e2d Fix typo in comments 2022-11-25 16:35:08 -08:00
Tri Dao
d95ee1a95d Speed up compilation by splitting into separate .cu files 2022-11-25 16:30:18 -08:00
Tri Dao
b784ed73cf [Docs] Clarify OpenFold speedup 2022-11-25 10:49:17 -08:00
Tri Dao
d9021ae4ec [Docs] Mention OpenFold 2022-11-23 13:01:19 -08:00
Tri Dao
1feb94265c [ViT] Use dropout_add_ln for the 1st layer norm 2022-11-23 12:48:56 -08:00
Tri Dao
45bcf37b97 [Docs] Capitalize the bibtex citation 2022-11-22 02:12:22 -08:00
Tri Dao
b8ccd20098 [Triton] Fix variable name from qkv to kv (h/t FrankZijlstra) 2022-11-22 02:07:32 -08:00
Tri Dao
054816177e Bump version to 0.2.1 2022-11-20 22:35:59 -08:00
Tri Dao
0fa5c0d7ef Add PatchEmbed 2022-11-17 16:56:06 -08:00
Tri Dao
ece539abd6 Add __init__.py files to subdirectories for installation 2022-11-17 16:55:44 -08:00
Tri Dao
39ed597b28 [LayerNorm] Compile for both sm70 and sm80 2022-11-17 11:45:11 -08:00
Tri Dao
71f674ae23 [Rotary] Customize base, support seqlen_offset 2022-11-17 11:43:36 -08:00
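For reference, the standard rotary formulation that these two knobs parameterize: `base` sets the frequency spectrum and `seqlen_offset` shifts the positions (useful when continuing generation past a cached prefix). A hedged sketch with illustrative names:

```python
import torch

def rotary_cos_sin(seqlen: int, dim: int, base: float = 10000.0,
                   seqlen_offset: int = 0, device=None):
    # Standard rotary frequencies: inv_freq[i] = base ** (-2i / dim).
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    # The offset shifts positions so cached tokens keep their original angles.
    t = torch.arange(seqlen_offset, seqlen_offset + seqlen, device=device).float()
    freqs = torch.outer(t, inv_freq)   # (seqlen, dim // 2)
    return freqs.cos(), freqs.sin()

cos, sin = rotary_cos_sin(seqlen=8, dim=64, base=10000.0, seqlen_offset=4)
```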
Tri Dao
d6ef701aa9 Set version to 0.2.0 (instead of 0.2) 2022-11-15 14:15:05 -08:00
Tri Dao
4040256b5e Update pip install instructions, bump to 0.2 2022-11-15 14:10:48 -08:00
Tri Dao
56aa49037d Merge pull request #75 from lucidrains/main
allow for uploading to pypi
2022-11-15 13:55:59 -08:00

Phil Wang
dcf3986590 update manifest 2022-11-15 13:38:13 -08:00
Phil Wang
b0eac3297f allow for uploading to pypi 2022-11-15 13:26:55 -08:00
Tri Dao
43ab0b5205 Mention that some CUDA extensions have only been tested on A100s 2022-11-15 07:10:25 -08:00
Tri Dao
e4d3013e15 [LayerNorm] Check cuda error after querying ctas_per_sm 2022-11-15 07:05:13 -08:00
Tri Dao
b0ed0a73fd Mention DeepSpeed inference in usage.md 2022-11-14 10:01:16 -08:00
Tri Dao
25387b24c1 Mention AITemplate Stable Diffusion in usage.md 2022-11-14 09:41:50 -08:00
Tri Dao
2e33fc8e36 Add GPT and ViT models 2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f Add MLP, MHA, Block, Embedding modules 2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2022-11-13 21:59:20 -08:00
Tri Dao
b92f2c3b67 Link to Colossal-AI's stable diffusion in usage.md 2022-11-13 20:49:05 -08:00
Tri Dao
343492ec30 Make nccl operations async in CrossEntropyLossParallel 2022-11-13 17:27:26 -08:00
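The general pattern for overlapping a NCCL collective with local compute in PyTorch looks roughly like the sketch below; it is illustrative rather than the module's actual code and assumes the default process group has already been initialized:

```python
import torch
import torch.distributed as dist

def reduce_while_computing(partial: torch.Tensor) -> torch.Tensor:
    # Launch the all-reduce asynchronously so independent local work can
    # proceed while NCCL communicates in the background.
    handle = dist.all_reduce(partial, op=dist.ReduceOp.SUM, async_op=True)
    # ... independent local computation goes here ...
    handle.wait()  # block only when the reduced result is actually needed
    return partial
```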
Tri Dao
3dda4f76de Update README 2022-11-13 16:52:40 -08:00
Tri Dao
79160a69a9 Add a page on where FlashAttention is being used 2022-11-13 16:40:18 -08:00
Tri Dao
a8fec99a9a Skip flash_attn_split test 2022-11-13 12:27:48 -08:00
Tri Dao
9d3116addf Don't enforce bitwise consistency for dq in race condition test
Since we could be parallelizing over seqlen_k
2022-11-13 12:21:51 -08:00
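Because atomic adds from different seqlen_k blocks can land in a different order on each run, floating-point non-associativity means dq is only reproducible up to rounding. A hypothetical sketch of the corresponding test-side change, checking closeness instead of bitwise equality:

```python
import torch

def assert_close_not_bitwise(dq: torch.Tensor, dq_ref: torch.Tensor):
    # Atomic adds make the summation order nondeterministic; the last few
    # bits may differ between runs even though the result is correct.
    assert torch.allclose(dq, dq_ref, rtol=1e-3, atol=1e-3)

dq_ref = torch.randn(2, 128, 8, 64)
dq = dq_ref + 1e-6 * torch.randn_like(dq_ref)  # simulate reordering noise
assert_close_not_bitwise(dq, dq_ref)
```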
Tri Dao
7c9953815a Add fused cross entropy loss 2022-11-12 21:58:41 -08:00
Tri Dao
55797f32c9 Remove RotaryEmbedding from FlashAttention module
To avoid import error if one doesn't have rotary_emb installed
2022-11-10 11:54:36 -08:00
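The commit simply drops the hard dependency; a common alternative for keeping an optional CUDA extension is a guarded import, sketched generically below (not the repo's code):

```python
try:
    import rotary_emb  # optional CUDA extension
except ImportError:
    rotary_emb = None

def apply_rotary(x):
    if rotary_emb is None:
        raise RuntimeError(
            "rotary_emb is not installed; install it to use rotary embeddings"
        )
    # ... call into the extension here ...
    return x
```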
Tri Dao
6998e0ecdb Fix out-of-bound memory read 2022-11-09 09:34:14 -08:00
Tri Dao
908a5b2244 Set num_warps=4 for headdim=64 in Triton fw (h/t Michael Benesty) 2022-11-07 08:58:16 -08:00
Tri Dao
7479757191 Fix pipelining bug in Triton bwd with bias_type=matrix 2022-11-06 11:50:35 -08:00
Tri Dao
557781933d Parallelize CUDA bwd along seqlen_k instead of seqlen_q
This is faster since we only need to do atomic adds on dq, instead of atomic
adds on both dk and dv.
2022-11-05 16:26:17 -07:00
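To see why, consider a schematic of the backward pass tiled over seqlen_k (softmax, scaling, and masking omitted; only the loop structure matters). Each k-tile owns its slice of dk and dv outright, but every tile contributes to every row of dq, so dq is the only tensor that needs atomic accumulation when tiles run in parallel. A hedged Python reference, not the kernel itself:

```python
import torch

def attn_bwd_reference(q, k, v, do, block_k=32):
    """Manual backward of the simplified op out = (q @ k.T) @ v, tiled over seqlen_k."""
    q, k, v, do = (t.detach() for t in (q, k, v, do))
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    s = q @ k.t()     # (seqlen_q, seqlen_k) scores
    ds = do @ v.t()   # (seqlen_q, seqlen_k) gradient w.r.t. the scores
    for start in range(0, k.shape[0], block_k):      # one CUDA block per k-tile
        sl = slice(start, min(start + block_k, k.shape[0]))
        # dk and dv for this tile depend only on this tile's columns,
        # so each block can write them without atomics.
        dv[sl] = s[:, sl].t() @ do
        dk[sl] = ds[:, sl].t() @ q
        # Every k-tile touches all rows of dq; with tiles running in parallel,
        # this "+=" is the atomic add in the real kernel.
        dq += ds[:, sl] @ k[sl]
    return dq, dk, dv

q, k, v = (torch.randn(128, 64, requires_grad=True) for _ in range(3))
out = (q @ k.t()) @ v
do = torch.randn_like(out)
dq, dk, dv = attn_bwd_reference(q, k, v, do)
out.backward(do)
assert torch.allclose(dq, q.grad, rtol=1e-3, atol=1e-3)
assert torch.allclose(dk, k.grad, rtol=1e-3, atol=1e-3)
```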
Tri Dao
ca81f32e04 Implement rotary embedding in CUDA 2022-11-04 22:42:01 -07:00
Tri Dao
62025e1aff Fix more race condition in Triton bwd when there's bias 2022-11-04 12:53:09 -07:00
Tri Dao
ff78ea4123 Fix race condition in Triton bwd when there's bias 2022-11-04 11:20:27 -07:00
Tri Dao
86862cfd7b Implement attention bias for Triton version 2022-11-04 10:33:54 -07:00
Tri Dao
470010f59b Fix race condition for Triton bwd for headdim 48 and 96 2022-11-03 15:52:40 -07:00