Tri Dao
08e9847176
[CI] Add CUDA 12.2
2023-09-03 02:45:42 -07:00
Sophia Wisdom
37e32febba
Remove commented-out code in bwd ( #512 )
* Remove lots of comments
* Remove unused traits
2023-09-01 16:43:58 -07:00
Sophia Wisdom
dd8a754915
Remove old code in utils.h ( #511 )
2023-09-01 15:32:09 -07:00
Aman Gupta Karmani
866a9d33f9
bump cutlass submodule ( #504 )
2023-08-30 10:32:04 -07:00
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa ( #491 )
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding ( #490 )
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not defined when using checkpoint_lvl = 2 ( #502 )
When using checkpoint_lvl=2, we call all_gather_raw(x) without async_op=True,
so there is no handle to wait for; just skip the wait.
2023-08-29 23:46:10 -07:00
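The fix above can be sketched as follows. This is an illustrative reconstruction, not the library's actual code: the names `all_gather_raw` and `handle_x` come from the commit message, but the surrounding structure is hypothetical. The point is that a synchronous all-gather returns no communication handle, so the wait must be skipped.

```python
def gather_input(x, checkpoint_lvl, all_gather_raw):
    """Hypothetical sketch: at checkpoint_lvl=2 the all-gather is
    synchronous (async_op=False), so no handle is returned."""
    if checkpoint_lvl == 2:
        total_x = all_gather_raw(x, async_op=False)  # result ready now
        handle_x = None
    else:
        total_x, handle_x = all_gather_raw(x, async_op=True)
    return total_x, handle_x

def use_gathered(total_x, handle_x):
    # Skip the wait when there is no handle -- this is the bug fix.
    if handle_x is not None:
        handle_x.wait()
    return total_x
```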
Tri Dao
31920dda5f
Fix typo with lse_max == -INFINITY
2023-08-29 21:48:59 -07:00
Tri Dao
8a326bbc9e
[Gen] Minor fix to modify logits for top_p
2023-08-29 14:29:06 -07:00
Jeffrey Quesnelle
1d817a8ffc
fix citation in README ( #501 )
2023-08-29 11:15:33 -07:00
Su Zhu
8f6f48d8a8
add unpad_input_for_concatenated_sequences ( #499 )
* add unpad_input_for_concatenated_sequences
* modify docstring
2023-08-29 02:23:56 -07:00
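The idea behind unpadding concatenated (packed) sequences is to turn the per-sequence lengths packed into each batch row into the cumulative offsets (cu_seqlens) that the varlen attention kernels consume. A hedged sketch of that offset computation, not the library's actual implementation:

```python
def cu_seqlens_from_lengths(lengths_per_row):
    """Illustrative: build cumulative sequence-length offsets from
    per-row lists of concatenated sequence lengths (zeros are padding
    slots and are ignored)."""
    cu = [0]
    for row in lengths_per_row:
        for n in row:
            if n > 0:
                cu.append(cu[-1] + n)
    return cu
```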
Tri Dao
b1fbbd8337
Implement splitKV attention
2023-08-29 00:58:29 -07:00
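SplitKV attention partitions the keys/values into chunks, computes a partial attention output and log-sum-exp (LSE) per chunk, and merges the partials exactly via the LSEs. A self-contained numerical sketch of that merge rule (single query, pure Python; the real kernel works on tiles in parallel):

```python
import math

def attn(q, ks, vs):
    """Reference single-query attention: softmax(q.k) weighted sum of v."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in ks]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e * v[d] for e, v in zip(exps, vs)) / z
            for d in range(len(vs[0]))]

def attn_splitkv(q, ks, vs, nsplits=2):
    """Split K/V into chunks; per chunk keep (partial output o_i, lse_i),
    then merge: o = sum_i exp(lse_i - lse) * o_i."""
    chunk = (len(ks) + nsplits - 1) // nsplits
    partials = []
    for s in range(0, len(ks), chunk):
        kc, vc = ks[s:s + chunk], vs[s:s + chunk]
        scores = [sum(a * b for a, b in zip(q, k)) for k in kc]
        m = max(scores)
        exps = [math.exp(x - m) for x in scores]
        z = sum(exps)
        o = [sum(e * v[d] for e, v in zip(exps, vc)) / z
             for d in range(len(vc[0]))]
        partials.append((o, m + math.log(z)))  # (partial output, chunk lse)
    lse = math.log(sum(math.exp(l) for _, l in partials))
    dim = len(partials[0][0])
    return [sum(math.exp(l - lse) * o[d] for o, l in partials)
            for d in range(dim)]
```

The merge is exact, not an approximation: each chunk's softmax denominator is recovered from its LSE, so the combined result equals full attention over all keys.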
Tri Dao
7a983df742
Use generate_kernels.py script from Driss Guessous
2023-08-28 13:34:12 -07:00
dan_the_3rd
c3f2a632aa
[ft_attention] Fix for seqlen=8136 ( #488 )
When seqlen=8136, `smem_sz = 48840`, and launching the kernel returns an `invalid argument` CUDA error.
`48840 < 48 * 1024`, yet it is apparently still above the limit somehow.
Tested on A100
2023-08-28 10:00:22 -07:00
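Background for the error above: CUDA kernels may request at most 48 KiB of dynamic shared memory per block by default; larger requests require an explicit opt-in via `cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, n)`. One plausible explanation for the failure at 48840 bytes (not confirmed by the commit) is that statically declared shared memory also counts against the block's budget, pushing the total past 48 KiB. A pure-Python model of that budget check, for illustration only:

```python
# Hedged illustration, not CUDA code: model the shared-memory budget check
# that can make a launch fail with "invalid argument".
DEFAULT_SMEM_LIMIT = 48 * 1024  # bytes available without the opt-in

def launch_would_fail(dynamic_smem, static_smem=0, opted_in_limit=None):
    """True if dynamic + static shared memory exceeds the effective limit
    (the default 48 KiB, or the value set via cudaFuncSetAttribute)."""
    limit = opted_in_limit if opted_in_limit is not None else DEFAULT_SMEM_LIMIT
    return dynamic_smem + static_smem > limit
```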
Tri Dao
757058d4d3
Update Cutlass to v3.2.0
2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a
[Gen] Clone logits before returning when cg=True
2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0
[GPT] Generalize last_token_only arg to num_last_tokens
2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a
[Gen] Fix decode function not using top_p during iterative decoding
2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c
[Gen] Refactor decode function a bit
2023-08-26 14:47:25 -07:00
Tri Dao
371e20658c
[GPT] Test generation when passing in multiple tokens
2023-08-26 13:56:41 -07:00
Tri Dao
c000c3a2c0
[GPT] Move more tests to test_gpt.py
2023-08-26 13:00:40 -07:00
Tri Dao
a2974e850a
Change causal for CrossAttention in mha.py to align to bottom right
2023-08-26 12:57:33 -07:00
Tri Dao
9b713872ea
[GPT] Move GPT and OPT generation tests to test_{gpt,opt}.py
2023-08-26 12:55:02 -07:00
Tri Dao
73bd3f3bbb
Move pyproject.toml to flash-attn and tests dir to avoid PEP 517
2023-08-25 15:05:28 -07:00
Aman Gupta Karmani
b4b6e90334
add benchmark for xformers fa2 wrapper ( #492 )
2023-08-25 14:10:05 -07:00
Tri Dao
45ba93cd96
Add newlines to README
2023-08-24 23:54:13 -07:00
Tri Dao
9e5e8bc91e
Change causal mask to be aligned to bottom-right instead of top-left
2023-08-24 23:41:07 -07:00
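With bottom-right alignment, query i may attend to key j iff `j <= i + (seqlen_k - seqlen_q)`, so the last query always sees all keys; when seqlen_q == seqlen_k this reduces to the usual lower-triangular mask. A small sketch of these semantics (illustrative; the kernels never materialize this matrix):

```python
def causal_mask_bottom_right(seqlen_q, seqlen_k):
    """Boolean (seqlen_q, seqlen_k) mask, True where attention is allowed,
    with the causal triangle aligned to the bottom-right corner."""
    shift = seqlen_k - seqlen_q
    return [[j <= i + shift for j in range(seqlen_k)]
            for i in range(seqlen_q)]
```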
BoxiangW
e07aa036db
Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. ( #436 )
2023-08-24 16:42:34 -07:00
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained ( #479 )
2023-08-24 16:31:16 -07:00
Tri Dao
6711b3bc40
Bump version to 2.0.9
2023-08-22 00:21:14 -07:00
Tri Dao
ef6d8c75d9
[GPT] Fix loading weights from HF hub
2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add code supporting baichuan-inc/Baichuan-7B ( #425 )
2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 ( #468 )
* q
* add comment.
2023-08-20 14:57:34 -07:00
Tri Dao
d431f16751
Import torch before flash_attn_2_cuda
2023-08-19 21:07:33 -07:00
Tri Dao
0e8c46ae08
Run isort and black on test files
2023-08-18 20:59:35 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format ( #465 )
* fix name.
* set inv function.
* add map back function.
* handle gqa.
* add type annotation to avoid confusion.
* fix docstr.
* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740
Run isort and black on python files
2023-08-18 14:22:11 -07:00
Tri Dao
cbb4cf5f46
Don't need to set TORCH_CUDA_ARCH_LIST in setup.py
2023-08-18 14:18:54 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 ( #461 )
* unequal rank.
* trim.
* enable passing in number of heads for each rank.
* simplify.
* simplify.
* cleanup.
* fix col parallel.
* fix bug with row parallel.
* fit out proj.
* refac.
* fix sharding logic.
* refac sharding.
* refac.
* support multiple of.
* make fn reusable.
* fix bug in dimensions.
* scaffold.
* test uneven heads.
* fix test by adding barrier.
* refac.
* reuse code.
* clean up.
2023-08-18 14:10:35 -07:00
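One natural scheme for splitting heads when num_heads is not divisible by world_size is to give the first `num_heads % world_size` ranks one extra head. This is a hedged sketch of that idea, not necessarily the PR's exact sharding logic:

```python
def heads_per_rank(num_heads, world_size):
    """Illustrative uneven head split: first `rem` ranks get base+1 heads."""
    base, rem = divmod(num_heads, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]
```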
Tri Dao
ada4710d70
[ViT] Run black on vit.py
2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1
[ViT] Minor fix so it runs
2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d
[GPT] Run black on gpt.py
2023-08-16 23:47:50 -07:00
Tri Dao
bec5b3d374
[MHA] Run black on mha.py
2023-08-16 23:47:13 -07:00
Tri Dao
cb0daccc41
[FusedDense] Allow Row/ColumnParallelLinear to have uneven split
2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751
[FusedDense] Run black on fused_dense.py
2023-08-16 23:41:36 -07:00
Tri Dao
2286d7cea7
Bump to v2.0.8
2023-08-16 15:13:12 -07:00
Tri Dao
c65b5106ac
Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal
2023-08-16 15:12:36 -07:00
Xuechen Li
0f7853c6a1
enable loading hf llama checkpoints for training ( #446 )
* prelim.
* add hf conversion fn.
* mlp.
* change name.
* fix bug.
* inverse permute.
* change comment.
* revert style changes.
* fix.
* add doc.
* revert.
* enable load safe.
* fix safe load.
* fix import.
* fix typing-related lints.
* fix ckpt loading logic.
* make single gpu work.
* test with parallel.
* ckpt format.
* enable pretrained state dict.
* remove unused imports.
* remove unused.
* mark idea related.
2023-08-15 08:33:15 -07:00
Tri Dao
c60851a825
Bump to v2.0.7
2023-08-14 14:55:35 -07:00
Aman Gupta Karmani
aab603af4f
fix binary wheel installation when nvcc is not available ( #448 )
2023-08-14 14:54:26 -07:00