Tri Dao
e0fbaa7016
[Gen] Simplify decode_speculative
2023-09-19 22:20:22 -07:00
Tri Dao
e6a8026489
[Gen] Rename max_sequence_len->max_seqlen, sequence_len_offset->seqlen_offset
2023-09-19 22:20:22 -07:00
Kevin Hu
42832575d4
Fix Llama GQA/MQA (#546)
* Fix llama MQA
* Fix permute shape
* Update llama.py
2023-09-19 22:15:59 -07:00
Tri Dao
dfe29f5e2b
[Gen] Don't use ft_attention, use flash_attn_with_kvcache instead
2023-09-18 15:29:06 -07:00
Tri Dao
799f56fa90
Don't compile for Pytorch 2.1 on CUDA 12.1 due to nvcc segfaults
2023-09-17 22:15:38 -07:00
Tri Dao
c984208ddb
Set block size to 64 x 64 for kvcache to avoid nvcc segfaults
2023-09-17 16:14:58 -07:00
Tri Dao
8c8b4d36e1
Bump to v2.2.3
2023-09-16 01:47:01 -07:00
Tri Dao
ccbb14f38e
Implement rotary embedding in flash_attn_with_kvcache
2023-09-16 01:20:16 -07:00
Tri Dao
5400fdc4ac
[CE] Implement CrossEntropyLoss in Triton
2023-09-15 20:05:28 -07:00
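The Triton cross-entropy added here is intended as a faster drop-in for the PyTorch loss. A minimal sketch, assuming the module path flash_attn.losses.cross_entropy.CrossEntropyLoss and a torch.nn.CrossEntropyLoss-style interface (CUDA tensors required):
```
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

vocab_size, num_tokens = 32000, 8
logits = torch.randn(num_tokens, vocab_size, device="cuda",
                     dtype=torch.float16, requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

loss_fn = CrossEntropyLoss()   # assumed drop-in for nn.CrossEntropyLoss
loss = loss_fn(logits, labels)
loss.backward()
```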
Tri Dao
d0032700d1
Add tests for Pythia, GPT-JT, and RedPajama models
2023-09-13 01:10:39 -07:00
Tri Dao
08c295c043
Bump to v2.2.2
2023-09-10 23:48:12 -07:00
Tri Dao
ee77b931b9
Swap seqlen_q and nheads for MQA to speed it up (h/t Daniel Haziza)
2023-09-10 22:56:33 -07:00
Kevin Hu
07005806ff
Add BigCode converters (#532)
2023-09-10 17:24:50 -07:00
Tri Dao
8a733cbd53
[Gen] Fix calling update_graph_cache in tests
2023-09-10 17:22:37 -07:00
Kevin Hu
4c91621a5e
Inverse state dict for BERT (#527)
2023-09-09 01:44:21 -07:00
Tri Dao
a86442f0f3
[Gen] Use flash_attn_with_kvcache in generation
2023-09-07 08:24:43 -07:00
Tri Dao
a1576ad1e8
Bump to v2.2.1
2023-09-06 02:19:55 -07:00
Tri Dao
9795159082
[Rotary] Set device before launching Triton kernel to avoid error
2023-09-05 21:29:03 -07:00
Tri Dao
6d673cd961
Bump to v2.2.0
2023-09-05 11:34:13 -07:00
Kyeongpil Kang
8e893f0950
Create __init__.py for ops/triton dir (#516)
2023-09-05 11:29:03 -07:00
Tri Dao
fd20f16a4e
Support cache_seqlens being integer
2023-09-05 11:27:48 -07:00
Tri Dao
913922cac5
[Gen] Refactor decoding function
2023-09-04 17:01:38 -07:00
Tri Dao
3557e0bb8f
[MLP] Implement SwiGLU with torch jiterator
2023-09-04 15:43:53 -07:00
Tri Dao
37c6e05406
Implement flash_attn_with_kvcache
2023-09-04 00:11:44 -07:00
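flash_attn_with_kvcache attends against a pre-allocated KV cache and can append the new keys/values in the same call, which is what the generation path above switches to. A minimal single-step decoding sketch with illustrative shapes (requires a CUDA GPU and fp16/bf16 inputs):
```
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 16, 64, 1024
dev, dtype = "cuda", torch.float16

# Pre-allocated KV cache; cache_seqlens tracks how many tokens each sequence holds.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=dev, dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device=dev, dtype=dtype)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device=dev)

# One decoding step: a single new token per sequence.
q = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dtype)
k_new = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dtype)
v_new = torch.randn(batch, 1, nheads, headdim, device=dev, dtype=dtype)

# Writes k_new/v_new into the caches at position cache_seqlens, then attends over the cache.
out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                              cache_seqlens=cache_seqlens, causal=True)
cache_seqlens += 1
```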
Tri Dao
4976650f74
Set single threaded compilation for CUDA 12.2 so CI doesn't OOM
2023-09-03 23:42:55 -07:00
Tri Dao
6a89b2f121
Remove constexpr in launch template to fix CI compilation
2023-09-03 22:59:41 -07:00
Tri Dao
97ba7a62e9
Try switching back to Cutlass 3.2.0
2023-09-03 22:45:35 -07:00
Tri Dao
1dc1b6c8f2
Bump to v2.1.2
2023-09-03 22:23:05 -07:00
Tri Dao
798858f9f1
Fix test_baichuan
2023-09-03 21:01:37 -07:00
Tri Dao
7b33743a72
[Gen] Add back num_last_tokens in gpt.py
2023-09-03 20:44:40 -07:00
Tri Dao
b28ec236df
[Rotary] Implement varlen rotary
2023-09-03 17:57:10 -07:00
Tri Dao
861c82577d
[Rotary] Clean up rotary Triton implementation a bit
2023-09-03 16:41:17 -07:00
Tri Dao
1c523c1ce1
[Rotary] Speed up rotary kernel when interleaved=True
2023-09-03 16:24:37 -07:00
Tri Dao
de2949f37d
[Rotary] Pass max_seqlen from mha.py to rotary during inference
2023-09-03 11:37:06 -07:00
Tri Dao
942fcbf046
[Rotary] Implement rotary in Triton
2023-09-03 02:51:58 -07:00
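The Triton rotary kernel is exposed functionally. A minimal sketch, assuming apply_rotary_emb(x, cos, sin, interleaved=...) in flash_attn.layers.rotary and illustrative shapes:
```
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 16, 64, 64
dev, dtype = "cuda", torch.float16

x = torch.randn(batch, seqlen, nheads, headdim, device=dev, dtype=dtype)

# cos/sin tables of shape (seqlen, rotary_dim // 2), built the usual RoPE way.
inv_freq = 1.0 / 10000.0 ** (torch.arange(0, rotary_dim, 2, device=dev).float() / rotary_dim)
freqs = torch.outer(torch.arange(seqlen, device=dev).float(), inv_freq)
cos, sin = freqs.cos().to(dtype), freqs.sin().to(dtype)

# Rotates the first rotary_dim dimensions of each head; interleaved=True gives GPT-J style.
x_rot = apply_rotary_emb(x, cos, sin, interleaved=False)
```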
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa (#491)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not defined when using checkpoint_lvl = 2 (#502)
When using checkpoint_lvl=2, we call all_gather_raw(x) without async_op=True, so there is no handle to wait for; just skip it.
2023-08-29 23:46:10 -07:00
Tri Dao
8a326bbc9e
[Gen] Minor fix to modify logits for top_p
2023-08-29 14:29:06 -07:00
Su Zhu
8f6f48d8a8
add unpad_input_for_concatenated_sequences (#499)
* add unpad_input_for_concatenated_sequences
* modify docstring
2023-08-29 02:23:56 -07:00
Tri Dao
757058d4d3
Update Cutlass to v3.2.0
2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a
[Gen] Clone logits before returning when cg=True
2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0
[GPT] Generalize last_token_only arg to num_last_tokens
2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a
[Gen] Fix decode function not using top_p during iterative decoding
2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c
[Gen] Refactor decode function a bit
2023-08-26 14:47:25 -07:00
Tri Dao
a2974e850a
Change causal for CrossAttention in mha.py to align to bottom right
2023-08-26 12:57:33 -07:00
Tri Dao
73bd3f3bbb
Move pyproject.toml to flash-attn and tests dir to avoid PEP 517
2023-08-25 15:05:28 -07:00
Tri Dao
9e5e8bc91e
Change causal mask to be aligned to bottom-right instead of top-left
2023-08-24 23:41:07 -07:00
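To make the change above concrete, a small illustration of the two alignments when the query is shorter than the key (e.g. appending new tokens to a cache); True means query i may attend to key j:
```
import torch

seqlen_q, seqlen_k = 2, 5
i = torch.arange(seqlen_q)[:, None]
j = torch.arange(seqlen_k)[None, :]

top_left     = j <= i                              # old behavior: aligned to top-left
bottom_right = j <= i + (seqlen_k - seqlen_q)      # new behavior: aligned to bottom-right
# top_left:      [[1, 0, 0, 0, 0],
#                 [1, 1, 0, 0, 0]]
# bottom_right:  [[1, 1, 1, 1, 0],
#                 [1, 1, 1, 1, 1]]
```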
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained (#479)
2023-08-24 16:31:16 -07:00
Tri Dao
6711b3bc40
Bump version to 2.0.9
2023-08-22 00:21:14 -07:00
Tri Dao
ef6d8c75d9
[GPT] Fix loading weights from HF hub
2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add code supporting baichuan-inc/Baichuan-7B (#425)
2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)
* q
* add comment.
2023-08-20 14:57:34 -07:00
Tri Dao
d431f16751
Import torch before flash_attn_2_cuda
2023-08-19 21:07:33 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format (#465)
* fix name.
* set inv function.
* add map back function.
* handle gqa.
* add type annotation to avoid confusion.
* fix docstr.
* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reuseable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
ada4710d70
[ViT] Run black on vit.py
2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1
[ViT] Minor fix so it runs
2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d
[GPT] Run black on gpt.py
2023-08-16 23:47:50 -07:00
Tri Dao
bec5b3d374
[MHA] Run black on mha.py
2023-08-16 23:47:13 -07:00
Tri Dao
cb0daccc41
[FusedDense] Allow Row/ColumnParallelLinear to have uneven split
2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751
[FusedDense] Run black on fused_dense.py
2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac
Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal
2023-08-16 15:12:36 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training (#446)
* prelim.
* add hf convertion fn.
* mlp.
* change name.
* fix bug.
* inverse permute.
* change comment.
* revert style changes.
* fix.
* add doc.
* revert.
* enable load safe.
* fix safe load.
* fix import.
* fix typing-related lints.
* fix ckpt loading logic.
* make single gpu work.
* test with parallel.
* ckpt format.
* enable pretrained state dict.
* remove unused imports.
* remove unused.
* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
c60851a825
Bump to v2.0.7
2023-08-14 14:55:35 -07:00
Tri Dao
f8dccfc90a
[CI] Fix MATRIX_CUDA_VERSION check
2023-08-14 10:27:26 -07:00
Tri Dao
9c531bdc0a
Use single thread compilation for cuda12.1, torch2.1 to avoid OOM CI
2023-08-14 10:03:31 -07:00
Tri Dao
67ae6fd74b
Bump to v2.0.6
2023-08-13 16:52:48 -07:00
Tri Dao
c5e87b11e9
Bump to v2.0.5
2023-08-13 13:55:04 -07:00
Tri Dao
364a5b4a71
[MLP] Change the check for out_features being None
2023-08-10 00:04:38 -07:00
Tri Dao
d30f2e1cd5
Bump to v2.0.4
2023-08-01 09:01:07 -07:00
Tri Dao
a4e5d1eddd
Bump to v2.0.3
2023-07-31 17:49:23 -07:00
Tri Dao
8f4cd4c16b
[Docs] Fix docstring about Q nheads being divisible by KV nheads
2023-07-31 17:47:03 -07:00
Tri Dao
184b992dcb
[GPT] Implement parallel LLaMa
2023-07-28 15:52:48 -10:00
Tri Dao
840f7925a0
[Docs] Fix mention of MQA/GQA in qkvpacked functions
2023-07-28 12:26:29 -10:00
Tri Dao
60499abcfd
[Benchmark] Add script to benchmark FlashAttention
2023-07-28 00:26:52 -10:00
Kirthi Shankar Sivamani
32a953f486
Request for v2.0.2 (#388)
* Bump version to 2.0.2
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Update version in Dockerfile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-28 02:46:03 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Save seed and offset for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Single thread write to global mem
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1colblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* compute_dq_dk_dv_1rowblock get seed and offset from launch params
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change forward c++ APIs to save RNG state for backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Change backward c++ APIs to set RNG state for bprop launcher
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Python side API changes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Bug fix; only save seeds instead of full offset
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
Tri Dao
4c98d0b41f
[MLP] Edit ParallelGatedMlp
2023-07-26 09:39:37 -10:00
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp (#251)
2023-07-26 12:14:15 -07:00
Tri Dao
b252072409
Bump to v2.0.1
2023-07-23 12:33:42 -10:00
Tri Dao
d38357dd2f
[GPT] Implement Falcon
2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert (#363)
2023-07-23 00:21:45 -07:00
Tri Dao
425dbcb6c6
[MHA] Implement MQA/GQA
2023-07-23 00:06:58 -07:00
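At the kernel level (which the MHA module wraps), MQA/GQA just means fewer KV heads than query heads, with the number of query heads a multiple of the number of KV heads. A minimal sketch with illustrative shapes:
```
import torch
from flash_attn import flash_attn_func

batch, seqlen, headdim = 2, 512, 64
nheads_q, nheads_kv = 32, 4   # GQA: 8 query heads share each KV head (nheads_kv=1 would be MQA)

q = torch.randn(batch, seqlen, nheads_q, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch, seqlen, nheads_kv, headdim, device="cuda", dtype=torch.float16)

out = flash_attn_func(q, k, v, causal=True)   # out: (batch, seqlen, nheads_q, headdim)
```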
Tri Dao
ec9f74ab9a
[Rotary] Don't store inv_freq in state_dict
2023-07-22 23:52:42 -07:00
Tri Dao
75e334d407
[MLP] Add ParallelMLP
2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6
[GPT] Enable FlashAttention for GPT-J
2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2
[Block] Re-enable DropPath
2023-07-21 16:39:23 -07:00
Tri Dao
b4cc152e97
Make sure dout is contiguous
2023-07-17 21:54:44 -07:00
Tri Dao
4f285b3547
FlashAttention-2 release
2023-07-17 06:21:34 -07:00
Tri Dao
6d48e14a6c
Bump to v1.0.9
2023-07-17 03:16:40 -07:00
Volodymyr Kyrylov
70ab266a56
rotary: update cos/sin cache when switching from inference mode
This resolves RuntimeErrors after running evaluation in inference mode:
```
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/modules/mha.py", line 492, in forward
qkv = self.rotary_emb(qkv)
File "/home/proger/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/proger/.local/lib/python3.10/site-packages/flash_attn/layers/rotary.py", line 229, in forward
return apply_rotary_emb_qkv_(
File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.
```
2023-07-08 12:01:07 +02:00
Tri Dao
d2f4324f4c
[LayerNorm] Make sure memory addresses are aligned to 16 bytes
2023-07-04 14:53:12 -07:00
Tri Dao
e8a0b4acdd
[Doc] Change total -> total_q
2023-07-02 17:23:52 -07:00
Tri Dao
9610114ce8
Bump to v1.0.8
2023-07-02 17:04:54 -07:00
Tri Dao
62e9814466
[Rotary] Make sure frequency calculation is in fp32
2023-07-02 16:39:39 -07:00
ljss
8e44c0eefb
Fix a bug
2023-06-02 13:46:19 +08:00
Tri Dao
85b51d61ee
Bump version to 1.0.7
2023-05-30 14:18:44 -07:00
Tri Dao
48bc6eacd6
[Gen] Add rotary base as an argument to FT attention kernel
2023-05-30 13:38:34 -07:00
Kirthi Shankar Sivamani
dd9c3a1fc2
bump to v1.0.6
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-05-26 17:44:10 -07:00
Max H. Gerlach
31f78a9814
Allow adding an optional local version to the package version
2023-05-19 17:27:41 +02:00
Federico Berto
69f5f7d0a2
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:44 +09:00
Federico Berto
3889ba168b
[BugFix] cannot unpack non-iterable NoneType object
2023-05-07 03:07:30 +09:00
Tri Dao
a9a4b4e4f2
[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
2023-05-04 23:39:43 -07:00
Tri Dao
fcab93b43a
[Gen] Minor tweak to allocate_inference_cache
2023-04-21 11:56:47 -07:00
Tri Dao
ba2fe7f378
[Gen] Move allocate_inference_cache to within the model
2023-04-20 18:15:12 -07:00
Tri Dao
3da42d24b1
[GPT] Add option to only return the logit for the last token
2023-04-20 17:21:08 -07:00
Tri Dao
311d6606bf
[Gen] Fix FT kernel smem size, CG when batch size changed
2023-04-20 17:03:13 -07:00
Tri Dao
96d10f6545
Implement LLaMa
2023-04-18 21:51:35 -07:00
Tri Dao
b630aef53f
Implement GatedMlp
2023-04-18 03:37:14 -07:00
Tri Dao
ac3b684cdb
Have a separate nn.Dropout module in SelfAttention module
2023-04-17 22:34:05 -07:00
Kirthi Shankar Sivamani
a0997bc77c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-14 21:45:37 -07:00
Tri Dao
605655bc66
[Gen] Fix FT kernel when using CG
2023-04-14 16:50:01 -07:00
Kirthi Shankar Sivamani
081c2b012a
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-13 19:36:45 -07:00
Tri Dao
1c9ef9b399
[Gen] Measure prompt processing + decoding time, not just decoding
2023-04-13 15:39:56 -07:00
Tri Dao
6f6e9a9aaf
[FusedDense] Enable sqrelu activation in FusedMLP
2023-04-13 15:29:32 -07:00
Kirthi Shankar Sivamani
7d25a4ec4f
Handle FlashAttnQKVPackedSplitFunc by making rng_state optional in backward
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-13 06:25:52 +00:00
Kirthi Shankar Sivamani
315fd31f0c
Merge branch 'HazyResearch:main' into enable_cuda_graph_capture
2023-04-12 22:42:24 -07:00
Zhiyuan Chen
8c42415664
make mlp hidden_features default to 4*in_features
2023-04-13 11:08:21 +08:00
Kirthi Shankar Sivamani
31018c5fa0
Support CUDA graph capture
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-04-12 16:53:22 -07:00
Tri Dao
d6fc860573
Merge pull request #147 from ksivaman/add_deterministic_execution_option
Add option for deterministic execution
2023-03-31 17:32:50 -04:00
Tri Dao
393882bc08
[LayerNorm] Implement LN with parallel residual, support dim 8k
2023-03-31 14:23:45 -07:00
Kirthi Shankar Sivamani
b6aa059bbf
Add option for deterministic execution
2023-03-30 18:23:35 -07:00
Tri Dao
993d12448e
Implement GPT-NeoX
2023-03-29 01:21:25 -07:00
Tri Dao
f5d0fbd468
[FT] Fix FT's single query attention for bf16 hdim128 rotary
2023-03-28 21:27:00 -07:00
Tri Dao
4d87e4d875
Implement GPT-J
2023-03-22 16:16:58 -07:00
Tri Dao
5d079fdd7a
[Triton] Fix benchmark_causal, mention Triton version
2023-03-22 00:51:16 -07:00
Tri Dao
dc08ea1c33
Support H100 for other CUDA extensions
2023-03-15 16:59:27 -07:00
Vik Paruchuri
3165398074
Remove unused kwargs in flashattention
2023-03-15 10:36:19 -07:00
Tri Dao
e45a46a5b7
[Rotary] Implement GPT-J style (interleaved) rotary
2023-03-14 14:35:53 -07:00
Tri Dao
78b7a1dc18
[OPT] Load fp16 weights on CPU before moving to GPU
2023-01-22 17:01:32 -08:00
Tri Dao
eb33e587e9
[LayerNorm] Rename x1 -> residual
2023-01-19 13:07:27 -08:00
Tri Dao
f68d41ec77
[Gen] Add OPT to generation test
2023-01-17 19:59:06 -08:00
Tri Dao
88173a1aaf
[FusedDense] Support relu, rename FusedDenseGeluDense -> FusedMLP
2023-01-17 18:12:27 -08:00
Tri Dao
780e8eeabb
[ViT] Support timm checkpoint, add tests
2023-01-16 01:20:34 -08:00
Tri Dao
2ec7d3f72c
Merge pull request #105 from jamaliki/patch-1
Change default dropout value in documentation
2023-01-15 23:01:20 -08:00
Tri Dao
ef085cfcda
[ViT] Fix extra norm_0, use new LN order in Block
2023-01-15 22:58:56 -08:00
Tri Dao
ff34123bd4
Reorder LN in Block, support OPT
2023-01-15 22:14:31 -08:00
Tri Dao
7c2191542a
[Gen] Make generation work with Tensor Parallel
2023-01-15 11:34:27 -08:00
Kiarash Jamali
41cb909741
Change default dropout value in documentation
Documentation says the default is 0.1, but the code defaults attention_dropout to 0.0
2023-01-13 10:50:07 +00:00
Tri Dao
f95c2fc108
[Gen] Remove commented code
2023-01-07 19:06:39 -08:00
Tri Dao
b48599002a
[Gen] Add timing option
2023-01-07 19:05:09 -08:00
Tri Dao
0938298e4c
[Gen] Adjust shape of kv_cache when using FT
2023-01-07 17:27:54 -08:00
Tri Dao
e02fd588aa
[Gen] Implement top-k and top-p sampling
2023-01-07 17:00:02 -08:00
Tri Dao
11be742aa3
[Gen] Test generation with rotary embedding
2023-01-07 14:37:54 -08:00
Tri Dao
8d9674ed08
Merge pull request #102 from Lamikins/main
fixed cross attention typeerror
2023-01-07 13:56:20 -08:00
Tri Dao
93383bd55b
[TP] Implement TensorParallel without sequence parallel
2023-01-07 13:45:22 -08:00
Darius Lam
aec35fd67c
fixed cross attention typeerror
2023-01-07 12:58:41 -08:00
Tri Dao
6738d9477d
[LayerNorm] Implement RMS Norm
2023-01-06 17:34:22 -08:00