Tri Dao
757058d4d3
Update Cutlass to v3.2.0
2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a
[Gen] Clone logits before returning when cg=True
2023-08-27 23:19:58 -07:00
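A note on the cg=True fix: a captured CUDA graph reuses one static output buffer on every replay, so the logits must be copied out before the next decoding step. A minimal sketch of the pattern (hypothetical names, not the repo's actual API):

```python
import torch

# Hypothetical names, not the repo's API: a captured graph writes into one
# static output buffer on every replay, so the result must be copied out
# before the next decoding step overwrites it.
def replay_and_read(graph: torch.cuda.CUDAGraph, static_logits: torch.Tensor) -> torch.Tensor:
    graph.replay()                 # in-place write into static_logits
    return static_logits.clone()   # snapshot before the next replay mutates it
```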
Tri Dao
f8aea6ead0
[GPT] Generalize last_token_only arg to num_last_tokens
2023-08-26 20:47:53 -07:00
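Sketch of what the generalization enables (illustrative code, not the repo's gpt.py): where the boolean last_token_only kept just the final position, num_last_tokens=k keeps the last k positions before the LM head.

```python
import torch

# Illustrative, not the repo's gpt.py: num_last_tokens=k keeps only the
# final k positions before the LM head (k=0 keeps all), generalizing the
# old boolean last_token_only and saving compute during decoding.
def select_last_tokens(hidden_states: torch.Tensor, num_last_tokens: int = 0) -> torch.Tensor:
    # hidden_states: (batch, seqlen, hidden_dim)
    if num_last_tokens > 0:
        hidden_states = hidden_states[:, -num_last_tokens:]
    return hidden_states
```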
Tri Dao
7a3bd55f1a
[Gen] Fix decode function not using top_p during iterative decoding
2023-08-26 15:14:41 -07:00
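For context, a generic top-p (nucleus) sampling step, written from the standard algorithm rather than this repo's decode function:

```python
import torch

# Generic nucleus sampling: keep the smallest set of tokens whose cumulative
# probability exceeds top_p, renormalize, and sample from that set.
def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)                           # (batch, vocab)
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    # Drop a token once the mass accumulated *before* it already exceeds top_p
    # (the subtraction guarantees the most likely token always survives).
    sorted_probs[cum - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)    # (batch, 1)
    return sorted_idx.gather(-1, next_sorted)                       # vocab ids
```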
Tri Dao
847abe653c
[Gen] Refactor decode function a bit
2023-08-26 14:47:25 -07:00
Tri Dao
a2974e850a
Change causal for CrossAttention in mha.py to align to bottom-right
2023-08-26 12:57:33 -07:00
Tri Dao
73bd3f3bbb
Move pyproject.toml to flash-attn and tests dir to avoid PEP 517
2023-08-25 15:05:28 -07:00
Tri Dao
9e5e8bc91e
Change causal mask to be aligned to bottom-right instead of top-left
2023-08-24 23:41:07 -07:00
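This commit and the CrossAttention change above adopt the same convention: with bottom-right alignment the mask diagonal ends at the bottom-right corner when seqlen_q != seqlen_k, so query i may attend to keys j <= i + seqlen_k - seqlen_q. A plain-PyTorch sketch of the mask (not the kernel code):

```python
import torch

# Bottom-right alignment: query i attends to keys j <= i + seqlen_k - seqlen_q,
# rather than j <= i (top-left, the previous convention). The two agree when
# seqlen_q == seqlen_k.
def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    i = torch.arange(seqlen_q).unsqueeze(1)  # query positions, column vector
    j = torch.arange(seqlen_k).unsqueeze(0)  # key positions, row vector
    return j <= i + (seqlen_k - seqlen_q)    # True = attend

print(bottom_right_causal_mask(2, 5).int())
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```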
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained (#479)
2023-08-24 16:31:16 -07:00
Tri Dao
6711b3bc40
Bump version to 2.0.9
2023-08-22 00:21:14 -07:00
Tri Dao
ef6d8c75d9
[GPT] Fix loading weights from HF hub
2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add code supporting baichuan-inc/Baichuan-7B (#425)
2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)
* q
* add comment.
2023-08-20 14:57:34 -07:00
Tri Dao
d431f16751
Import torch before flash_attn_2_cuda
2023-08-19 21:07:33 -07:00
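The likely rationale, sketched below: importing torch first loads libtorch's shared libraries, so the compiled extension can resolve its symbols when imported next.

```python
# Import order sketch: torch must load libtorch's shared libraries before
# the extension module, which links against them, is imported.
import torch  # noqa: F401
import flash_attn_2_cuda
```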
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format (#465)
* fix name.
* set inv function.
* add map back function.
* handle gqa.
* add type annotation to avoid confusion.
* fix docstr.
* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Xuechen Li
bb4cded17b
support the case where num_heads is not divisible by world_size; resolves #459 (#461)
* unequal rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reusable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
2023-08-18 14:10:35 -07:00
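One reasonable split policy for the uneven case this PR supports (the PR's exact scheme may differ): the first num_heads % world_size ranks each take one extra head.

```python
# One reasonable split policy when num_heads % world_size != 0 (the PR's
# exact scheme may differ): the first (num_heads % world_size) ranks each
# take one extra head.
def heads_per_rank(num_heads: int, world_size: int) -> list[int]:
    base, extra = divmod(num_heads, world_size)
    return [base + 1 if rank < extra else base for rank in range(world_size)]

assert heads_per_rank(10, 4) == [3, 3, 2, 2]
```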
Tri Dao
ada4710d70
[ViT] Run black on vit.py
2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1
[ViT] Minor fix so it runs
2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d
[GPT] Run black on gpt.py
2023-08-16 23:47:50 -07:00
Tri Dao
bec5b3d374
[MHA] Run black on mha.py
2023-08-16 23:47:13 -07:00
Tri Dao
cb0daccc41
[FusedDense] Allow Row/ColumnParallelLinear to have uneven split
2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751
[FusedDense] Run black on fused_dense.py
2023-08-16 23:41:36 -07:00
Tri Dao
c65b5106ac
Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal
2023-08-16 15:12:36 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training (#446)
* prelim.
* add hf conversion fn.
* mlp.
* change name.
* fix bug.
* inverse permute.
* change comment.
* revert style changes.
* fix.
* add doc.
* revert.
* enable load safe.
* fix safe load.
* fix import.
* fix typing-related lints.
* fix ckpt loading logic.
* make single gpu work.
* test with parallel.
* ckpt format.
* enable pretrained state dict.
* remove unused imports.
* remove unused.
* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
c60851a825
Bump to v2.0.7
2023-08-14 14:55:35 -07:00
Tri Dao
f8dccfc90a
[CI] Fix MATRIX_CUDA_VERSION check
2023-08-14 10:27:26 -07:00
Tri Dao
9c531bdc0a
Use single-threaded compilation for cuda12.1, torch2.1 to avoid OOM in CI
2023-08-14 10:03:31 -07:00
Tri Dao
67ae6fd74b
Bump to v2.0.6
2023-08-13 16:52:48 -07:00
Tri Dao
c5e87b11e9
Bump to v2.0.5
2023-08-13 13:55:04 -07:00
Tri Dao
364a5b4a71
[MLP] Change the check for out_features being None
2023-08-10 00:04:38 -07:00
Tri Dao
d30f2e1cd5
Bump to v2.0.4
2023-08-01 09:01:07 -07:00
Tri Dao
a4e5d1eddd
Bump to v2.0.3
2023-07-31 17:49:23 -07:00
Tri Dao
8f4cd4c16b
[Docs] Fix docstring about Q nheads being divisible by KV nheads
2023-07-31 17:47:03 -07:00
Tri Dao
184b992dcb
[GPT] Implement parallel LLaMa
2023-07-28 15:52:48 -10:00
Tri Dao
840f7925a0
[Docs] Fix mention of MQA/GQA in qkvpacked functions
2023-07-28 12:26:29 -10:00
Tri Dao
60499abcfd
[Benchmark] Add script to benchmark FlashAttention
2023-07-28 00:26:52 -10:00
Kirthi Shankar Sivamani
32a953f486
Request for v2.0.2 (#388)
* Bump version to 2.0.2
* Update version in Dockerfile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-28 02:46:03 -07:00
Kirthi Shankar Sivamani
a03f6f8e9e
Enable CUDA graphs (#386)
* Add RNG state to kernel launch params
* Save seed and offset for backward
* Single thread write to global mem
* compute_dq_dk_dv_1colblock get seed and offset from launch params
* compute_dq_dk_dv_1rowblock get seed and offset from launch params
* Change forward c++ APIs to save RNG state for backward
* Change backward c++ APIs to set RNG state for bprop launcher
* Bug fixes
* Python side API changes
* Bug fix; only save seeds instead of full offset
* Account for 3D grid size
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
2023-07-27 16:11:34 -07:00
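Conceptual sketch of the RNG handling in this PR (pure PyTorch, not the CUDA kernels): the forward pass records the dropout RNG state it used, and the backward pass regenerates the identical mask from that state, which keeps replays deterministic under graph capture.

```python
import torch

# Pure-PyTorch sketch of the idea, not the CUDA kernels: forward records the
# dropout RNG seed it used; backward regenerates the identical mask from that
# seed instead of saving the mask, so replays stay deterministic under graph
# capture, where Python-side RNG bookkeeping is frozen.
def dropout_fwd(x: torch.Tensor, p: float, seed: int):
    gen = torch.Generator(device=x.device).manual_seed(seed)
    mask = (torch.rand(x.shape, generator=gen, device=x.device) > p).to(x.dtype)
    return x * mask / (1 - p), seed  # save only the seed, not the mask

def dropout_bwd(dout: torch.Tensor, p: float, seed: int) -> torch.Tensor:
    gen = torch.Generator(device=dout.device).manual_seed(seed)
    mask = (torch.rand(dout.shape, generator=gen, device=dout.device) > p).to(dout.dtype)
    return dout * mask / (1 - p)
```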
Tri Dao
4c98d0b41f
[MLP] Edit ParallelGatedMlp
2023-07-26 09:39:37 -10:00
Haodong Lyu
8ee62efca3
Implement ParallelGatedMlp (#251)
2023-07-26 12:14:15 -07:00
Tri Dao
b252072409
Bump to v2.0.1
2023-07-23 12:33:42 -10:00
Tri Dao
d38357dd2f
[GPT] Implement Falcon
2023-07-23 10:32:29 -07:00
Kiarash Jamali
684196b8c5
Allow rotary embeddings for Bert (#363)
2023-07-23 00:21:45 -07:00
Tri Dao
425dbcb6c6
[MHA] Implement MQA/GQA
2023-07-23 00:06:58 -07:00
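The MQA/GQA contract, sketched in plain PyTorch (the fused kernel avoids this materialization):

```python
import torch

# MQA/GQA contract: nheads_q must be a multiple of nheads_kv, and each KV
# head serves nheads_q // nheads_kv query heads (MQA is nheads_kv == 1).
# A non-fused reference simply repeats KV heads; the kernel avoids the copy.
def expand_kv_heads(kv: torch.Tensor, nheads_q: int) -> torch.Tensor:
    # kv: (batch, seqlen, nheads_kv, headdim)
    nheads_kv = kv.shape[2]
    assert nheads_q % nheads_kv == 0, "Q heads must be a multiple of KV heads"
    return kv.repeat_interleave(nheads_q // nheads_kv, dim=2)
```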
Tri Dao
ec9f74ab9a
[Rotary] Don't store inv_freq in state_dict
2023-07-22 23:52:42 -07:00
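A sketch of the standard mechanism for changes like this: a non-persistent buffer stays out of state_dict, so checkpoints neither save nor require inv_freq.

```python
import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    # Sketch: registering inv_freq as a non-persistent buffer keeps it out
    # of state_dict, so checkpoints no longer carry (or require) it.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
```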
Tri Dao
75e334d407
[MLP] Add ParallelMLP
2023-07-22 23:45:51 -07:00
Tri Dao
b3177dfaf6
[GPT] Enable FlashAttention for GPT-J
2023-07-21 17:29:10 -07:00
Tri Dao
6fc1e07da2
[Block] Re-enable DropPath
2023-07-21 16:39:23 -07:00
Tri Dao
b4cc152e97
Make sure dout is contiguous
2023-07-17 21:54:44 -07:00
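Sketch of the usual guard behind this fix (illustrative, not the repo's exact check): backward kernels typically require dout's innermost dimension to be contiguous.

```python
import torch

# Illustrative guard: the backward kernel assumes the innermost dimension of
# dout has stride 1 (coalesced access), so a strided gradient is copied first.
def ensure_dout_contiguous(dout: torch.Tensor) -> torch.Tensor:
    return dout if dout.stride(-1) == 1 else dout.contiguous()
```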