Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c
Bump to v1.0.9
2023-07-17 03:16:40 -07:00
Volodymyr Kyrylov
70ab266a56
rotary: update cos/sin cache when switching from inference mode
...
This resolves RuntimeErrors after running evaluation in inference mode:
```
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
qkv = self.rotary_emb(qkv)
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
return apply_rotary_emb_qkv_(
File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
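For context, a minimal sketch of the failure mode behind the traceback above, assuming only standard PyTorch semantics: a tensor cached under `torch.inference_mode()` becomes an inference tensor and cannot be saved for backward in a later training pass. The `CachedScale` module and its attribute names below are hypothetical stand-ins for the rotary cos/sin cache, not flash_attn's actual implementation; the commit's real fix refreshes that cache when inference mode is no longer active.
```python
# Hypothetical reproduction of the "Inference tensors cannot be saved for backward" error
# and the clone workaround the error message suggests.
import torch


class CachedScale(torch.nn.Module):
    """Caches a tensor on first use, loosely mimicking the rotary cos/sin cache."""

    def __init__(self):
        super().__init__()
        self._scale_cached = None

    def forward(self, x):
        if self._scale_cached is None:
            # If this first runs under torch.inference_mode(), the cache is an
            # "inference tensor" and cannot later be saved for backward.
            self._scale_cached = torch.cos(torch.arange(x.shape[-1], dtype=x.dtype))
        return x * self._scale_cached  # mul saves the cached operand for backward


module = CachedScale()

with torch.inference_mode():          # evaluation pass builds the cache
    module(torch.randn(2, 8))

try:
    module(torch.randn(2, 8, requires_grad=True))  # training pass reuses the cache
except RuntimeError as err:
    print(err)  # "Inference tensors cannot be saved for backward. ..."

# Workaround from the error message: clone the cache so autograd sees a normal tensor.
# The commit instead rebuilds the cos/sin cache when leaving inference mode, so
# training after inference-mode evaluation works without user intervention.
module._scale_cached = module._scale_cached.clone()
out = module(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```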
Tri Dao
d2f4324f4c
[LayerNorm] Make sure memory addresses are aligned to 16 bytes
2023-07-04 14:53:12 -07:00
Tri Dao
e8a0b4acdd
[Doc] Change total -> total_q
2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8
Bump to v1.0.8
2023-07-02 17:04:54 -07:00
Tri Dao
62e9814466
[Rotary] Make sure frequency calculation is in fp32
2023-07-02 16:39:39 -07:00
ljss
8e44c0eefb
Fix a bug
2023-06-02 13:46:19 +08:00
Tri Dao
85b51d61ee
Bump version to 1.0.7
2023-05-30 14:18:44 -07:00
Tri Dao
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
2023-05-30 13:38:34 -07:00
Kirthi Shankar Sivamani
dd9c3a1fc2
Bump to v1.0.6
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-26 17:44:10 -07:00
Max H. Gerlach
31f78a9814
Allow adding an optional local version to the package version
2023-05-19 17:27:41 +02:00
Federico Berto
69f5f7d0a2
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:44 +09:00
Federico Berto
3889ba168b
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:30 +09:00
Tri Dao
a9a4b4e4f2
[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
2023-05-04 23:39:43 -07:00
Tri Dao
fcab93b43a
[Gen] Minor tweak to allocate_inference_cache
2023-04-21 11:56:47 -07:00
Tri Dao
ba2fe7f378
[Gen] Move allocate_inference_cache to within the model
2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1
[GPT] Add option to only return the logit for the last token
2023-04-20 17:21:08 -07:00
Tri Dao
311d6606bf
[Gen] Fix FT kernel smem size, CG when batch size changed
2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545
Implement LLaMa
2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f
Implement GatedMlp
2023-04-18 03:37:14 -07:00
Tri Dao
ac3b684cdb
Have a separate nn.Dropout module in SelfAttention module
2023-04-17 22:34:05 -07:00
Kirthi Shankar Sivamani
a0997bc77c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-14 21:45:37 -07:00
Tri Dao
605655bc66
[Gen] Fix FT kernel when using CG
2023-04-14 16:50:01 -07:00
Kirthi Shankar Sivamani
081c2b012a
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-13 19:36:45 -07:00
Tri Dao
1c9ef9b399
[Gen] Measure prompt processing + decoding time, not just decoding
2023-04-13 15:39:56 -07:00
Tri Dao
6f6e9a9aaf
[FusedDense] Enable sqrelu activation in FusedMLP
2023-04-13 15:29:32 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f
Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
315fd31f0c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-12 22:42:24 -07:00
Zhiyuan Chen
8c42415664
Make mlp hidden_features default to 4*in_features
2023-04-13 11:08:21 +08:00
Kirthi Shankar Sivamani
31018c5fa0
Support CUDA graph capture
...
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
d6fc860573
Merge pull request #147 from ksivaman/add_deterministic_execution_option
...
Add option for deterministic execution
2023-03-31 17:32:50 -04:00
Tri Dao
393882bc08
[LayerNorm] Implement LN with parallel residual, support dim 8k
2023-03-31 14:23:45 -07:00
Kirthi Shankar Sivamani
b6aa059bbf
Add option for deterministic execution
2023-03-30 18:23:35 -07:00
Tri Dao
993d12448e
Implement GPT-NeoX
2023-03-29 01:21:25 -07:00
Tri Dao
f5d0fbd468
[FT] Fix FT's single query attention for bf16 hdim128 rotary
2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875
Implement GPT-J
2023-03-22 16:16:58 -07:00
Tri Dao
5d079fdd7a
[Triton] Fix benchmark_causal, mention Triton version
2023-03-22 00:51:16 -07:00
Tri Dao
dc08ea1c33
Support H100 for other CUDA extensions
2023-03-15 16:59:27 -07:00
Vik Paruchuri
3165398074
Remove unused kwargs in FlashAttention
2023-03-15 10:36:19 -07:00
Tri Dao
e45a46a5b7
[Rotary] Implement GPT-J style (interleaved) rotary
2023-03-14 14:35:53 -07:00
Tri Dao
78b7a1dc18
[OPT] Load fp16 weights on CPU before moving to GPU
2023-01-22 17:01:32 -08:00
Tri Dao
eb33e587e9
[LayerNorm] Rename x1 -> residual
2023-01-19 13:07:27 -08:00
Tri Dao
f68d41ec77
[Gen] Add OPT to generation test
2023-01-17 19:59:06 -08:00
Tri Dao
88173a1aaf
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
2023-01-17 18:12:27 -08:00
Tri Dao
780e8eeabb
[ViT] Support timm checkpoint, add tests
2023-01-16 01:20:34 -08:00
Tri Dao
2ec7d3f72c
Merge pull request #105 from jamaliki/patch-1
...
Change default dropout value in documentation
2023-01-15 23:01:20 -08:00
Tri Dao
ef085cfcda
[ViT] Fix extra norm_0, use new LN order in Block
2023-01-15 22:58:56 -08:00
Tri Dao
ff34123bd4
Reorder LN in Block, support OPT
2023-01-15 22:14:31 -08:00
Tri Dao
7c2191542a
[Gen] Make generation work with Tensor Parallel
2023-01-15 11:34:27 -08:00
Kiarash Jamali
41cb909741
Change default dropout value in documentation
...
The documentation says the default is 0.1, but the code defaults attention_dropout to 0.0
2023-01-13 10:50:07 +00:00
Tri Dao
f95c2fc108
[Gen] Remove commented code
2023-01-07 19:06:39 -08:00
Tri Dao
b48599002a
[Gen] Add timing option
2023-01-07 19:05:09 -08:00
Tri Dao
0938298e4c
[Gen] Adjust shape of kv_cache when using FT
2023-01-07 17:27:54 -08:00
Tri Dao
e02fd588aa
[Gen] Implement top-k and top-p sampling
2023-01-07 17:00:02 -08:00
Tri Dao
11be742aa3
[Gen] Test generation with rotary embedding
2023-01-07 14:37:54 -08:00
Tri Dao
8d9674ed08
Merge pull request #102 from Lamikins/main
...
Fixed cross-attention TypeError
2023-01-07 13:56:20 -08:00
Tri Dao
93383bd55b
[TP] Implement TensorParallel without sequence parallel
2023-01-07 13:45:22 -08:00
Darius Lam
aec35fd67c
Fixed cross-attention TypeError
2023-01-07 12:58:41 -08:00
Tri Dao
6738d9477d
[LayerNorm] Implement RMS Norm
2023-01-06 17:34:22 -08:00
Tri Dao
a668890fcd
[Gen] Add option to run generation with FT attention kernel
2023-01-03 22:10:31 -08:00
Tri Dao
4cab4de5ea
[TP] Put parallel embeddings in separate modules
2023-01-02 08:47:48 -08:00
Tri Dao
1ec09ebd90
[FusedDense] Limit matrix dims to 2M (instead of 64k)
2023-01-01 17:06:39 -08:00
Tri Dao
714c1b4f0f
[Bert] Fix embedding layer norm before embedding dropout
2023-01-01 10:38:05 -08:00
Tri Dao
ef1ba918c6
[GPT] Refactor function to shard state_dict for TensorParallel
2023-01-01 00:09:33 -08:00
Tri Dao
65b4064b2a
[FusedDense] Kick off input all_gather before weight dtype conversion
2022-12-31 22:47:34 -08:00
Tri Dao
85b8e3d334
[Docs] Mention that XPos's scale_base is recommended to be 512
2022-12-29 20:25:02 -08:00
Tri Dao
a6ec1782dc
Bump to v0.2.6
2022-12-27 22:05:20 -08:00
Tri Dao
63670fd84a
Implement generation for GPT
2022-12-27 21:01:50 -08:00
Tri Dao
9d797d8848
Support loading GPT2 weights from Huggingface
2022-12-27 11:22:48 -08:00
Tri Dao
c6ecd40a59
Tweak CrossEntropyLoss to take process_group in init
2022-12-27 10:47:43 -08:00
Tri Dao
b4018a5028
Implement Tensor Parallel for GPT model
2022-12-26 16:22:43 -08:00
Tri Dao
78225c5366
Implement Tensor Parallel for GPT2Embeddings
2022-12-25 14:29:53 -08:00
Tri Dao
a8cfe51551
Implement Tensor Parallel for transformer Block
2022-12-25 14:08:21 -08:00
Tri Dao
1e712ea8b0
Implement TensorParallel for MHA
2022-12-25 11:39:55 -08:00
Tri Dao
226a1b721d
Implement TensorParallel for FusedDense and FusedDenseGeluDense
2022-12-24 11:48:56 -08:00
Tri Dao
dff68c2b22
Add smoothing for CrossEntropyParallel, rename to CrossEntropyLoss
2022-12-23 14:51:08 -08:00
Tri Dao
e68ebbe89a
Simplify FusedDense
2022-12-22 21:25:31 -08:00
Tri Dao
496e4f528c
Implement XPos (Sun et al.)
2022-12-21 14:17:58 -08:00
Tri Dao
13cdceb377
Implement last_layer_subset optimization for BERT
2022-12-19 22:18:46 -08:00
Tri Dao
5fb6df0e04
Implement BERT
2022-12-18 21:47:27 -08:00
Alexander Ploshkin
ee8984d2be
Add asserts for sin shape
2022-12-17 13:34:57 +04:00
Alexander Ploshkin
c7c66976cc
Fix slicing dimensions
2022-12-16 15:39:06 +04:00
Alexander Ploshkin
96656b9323
Remove redundant shape asserts in rotary embeddings
2022-12-15 18:13:21 +04:00
Tri Dao
6b5f271c6d
[Triton] Avoid einops repeat by using Tensor.expand
2022-12-14 14:48:41 -08:00
Tri Dao
88c4e5dbf6
Fix the case when dout is not contiguous
2022-12-13 13:58:17 -08:00
Tri Dao
5db330519a
[LayerNorm] Support taking subset of input or subset of output
2022-12-12 22:16:14 -08:00
Tri Dao
ae137ed17a
[LayerNorm] Fuse LayerScale
2022-12-10 23:28:23 -08:00
Tri Dao
8c6609ae1a
[LayerNorm] Support all dimensions up to 6k (if divisible by 8)
2022-12-09 02:06:22 -08:00
Tri Dao
1feb94265c
[ViT] Use dropout_add_ln for the 1st layer norm
2022-11-23 12:48:56 -08:00
Tri Dao
b8ccd20098
[Triton] Fix variable name from qkv to kv (h/t FrankZijlstra)
2022-11-22 02:07:32 -08:00
Tri Dao
054816177e
Bump version to 0.2.1
2022-11-20 22:35:59 -08:00
Tri Dao
0fa5c0d7ef
Add PatchEmbed
2022-11-17 16:56:06 -08:00
Tri Dao
ece539abd6
Add __init__.py files to subdirectories for installation
2022-11-17 16:55:44 -08:00
Tri Dao
71f674ae23
[Rotary] Customize base, support seqlen_offset
2022-11-17 11:43:36 -08:00
Tri Dao
2e33fc8e36
Add GPT and ViT models
2022-11-13 22:30:23 -08:00
Tri Dao
d4b320b31f
Add MLP, MHA, Block, Embedding modules
2022-11-13 22:06:44 -08:00
Tri Dao
fa6d1ce44f
Add fused_dense and dropout_add_layernorm CUDA extensions
2022-11-13 21:59:20 -08:00
Tri Dao
343492ec30
Make nccl operations async in CrossEntropyLossParallel
2022-11-13 17:27:26 -08:00
Tri Dao
7c9953815a
Add fused cross entropy loss
2022-11-12 21:58:41 -08:00