Commit Graph

631 Commits

Author SHA1 Message Date
Tri Dao
26d7d92f3d Fix splitKV combine function when local LSEs are all -inf 2023-09-03 11:39:09 -07:00
Tri Dao
de2949f37d [Rotary] Pass max_seqlen from mha.py to rotary during inference 2023-09-03 11:37:06 -07:00
Tri Dao
942fcbf046 [Rotary] Implement rotary in Triton 2023-09-03 02:51:58 -07:00
Tri Dao
08e9847176 [CI] Add CUDA 12.2 2023-09-03 02:45:42 -07:00
Sophia Wisdom
37e32febba
Remove commented out code in bwd (#512)
* Remove lots of comments

* Remove unused traits
2023-09-01 16:43:58 -07:00
Sophia Wisdom
dd8a754915
Remove old code in utils.h (#511) 2023-09-01 15:32:09 -07:00
Aman Gupta Karmani
866a9d33f9
bump cutlass submodule (#504) 2023-08-30 10:32:04 -07:00
dan_the_3rd
c9d4a816fa
Support LLaMa2 and CodeLLaMa (#491)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:31:14 -07:00
dan_the_3rd
011ec323d6
Support MQA + MP for decoding (#490)
Co-authored-by: danthe3rd <danthe3rd>
2023-08-30 10:29:54 -07:00
GAOXinyu
0cb595ad94
[bugfix] handle_x not define when using checkpoint_lvl = 2 (#502)
when using checkpoint_lvl=2, we all_gather_raw(x) without async_op=True.
So we don't need to wait for handle. Just skip.
2023-08-29 23:46:10 -07:00
Tri Dao
31920dda5f Fix typo with lse_max == -INFINITY 2023-08-29 21:48:59 -07:00
Tri Dao
8a326bbc9e [Gen] Minor fix to modify logits for top_p 2023-08-29 14:29:06 -07:00
Jeffrey Quesnelle
1d817a8ffc
fix citation in README (#501) 2023-08-29 11:15:33 -07:00
Su Zhu
8f6f48d8a8
add unpad_input_for_concatenated_sequences (#499)
* add unpad_input_for_concatenated_sequences

* modify docstring
2023-08-29 02:23:56 -07:00
Tri Dao
b1fbbd8337 Implement splitKV attention 2023-08-29 00:58:29 -07:00
Tri Dao
7a983df742 Use generate_kernels.py script from Driss Guessous 2023-08-28 13:34:12 -07:00
dan_the_3rd
c3f2a632aa
[ft_attention] Fix for seqlen=8136 (#488)
When seqlen=8136, `smem_sz = 48840`, and apparently starting the kernel returns an `invalid argument` CUDA error.

`48840 < 48 * 1024` but apparently it's still above the limit somehow..?
Tested on A100
2023-08-28 10:00:22 -07:00
Tri Dao
757058d4d3 Update Cutlass to v3.2.0 2023-08-27 23:47:28 -07:00
Tri Dao
9f42cb6e7a [Gen] Clone logits before returning when cg=True 2023-08-27 23:19:58 -07:00
Tri Dao
f8aea6ead0 [GPT] Generalize last_token_only arg to num_last_tokens 2023-08-26 20:47:53 -07:00
Tri Dao
7a3bd55f1a [Gen] Fix decode function not using top_p during iterative decoding 2023-08-26 15:14:41 -07:00
Tri Dao
847abe653c [Gen] Refactor decode function a bit 2023-08-26 14:47:25 -07:00
Tri Dao
371e20658c [GPT] Test generation when passing in multiple tokens 2023-08-26 13:56:41 -07:00
Tri Dao
c000c3a2c0 [GPT] Move more tests to test_gpt.py 2023-08-26 13:00:40 -07:00
Tri Dao
a2974e850a Change causal for CrossAttention in mha.py to align to bottom right 2023-08-26 12:57:33 -07:00
Tri Dao
9b713872ea [GPT] Move GPT and OPT generation tests to test_{gpt,opt}.py 2023-08-26 12:55:02 -07:00
Tri Dao
73bd3f3bbb Move pyproject.toml to flash-attn and tests dir to avoid PEP 517 2023-08-25 15:05:28 -07:00
Aman Gupta Karmani
b4b6e90334
add benchmark for xformers fa2 wrapper (#492) 2023-08-25 14:10:05 -07:00
Tri Dao
45ba93cd96 Add newlines to README 2023-08-24 23:54:13 -07:00
Tri Dao
9e5e8bc91e Change causal mask to be aligned to bottom-right instead of top-left 2023-08-24 23:41:07 -07:00
BoxiangW
e07aa036db
Support flash attention 2 with causal masking when KV's seq length is longer than Q's seq length. (#436) 2023-08-24 16:42:34 -07:00
Aman Gupta Karmani
e0b09891c6
add llama support to GPTPreTrainedModel.from_pretrained (#479) 2023-08-24 16:31:16 -07:00
Tri Dao
6711b3bc40 Bump version to 2.0.9 2023-08-22 00:21:14 -07:00
Tri Dao
ef6d8c75d9 [GPT] Fix loading weights from HF hub 2023-08-21 22:56:02 -07:00
GAOXinyu
a8c35b4f57
FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425) 2023-08-21 11:05:06 -07:00
Xuechen Li
25d6b1dbcb
handle uneven heads across ranks when combining state_dicts; resolves #467 (#468)
* q

* add comment.
2023-08-20 14:57:34 -07:00
Tri Dao
d431f16751 Import torch before flash_attn_2_cuda 2023-08-19 21:07:33 -07:00
Tri Dao
0e8c46ae08 Run isort and black on test files 2023-08-18 20:59:35 -07:00
Xuechen Li
7fcd3e6a04
map custom model state_dict back to huggingface format (#465)
* fix name.

* set inv function.

* add map back function.

* handle gqa.

* add type annotation to avoid confusion.

* fix docstr.

* test inverse remap logic.
2023-08-18 20:51:39 -07:00
Tri Dao
f1a73d0740 Run isort and black on python files 2023-08-18 14:22:11 -07:00
Tri Dao
cbb4cf5f46 Don't need to set TORCH_CUDA_ARCH_LIST in setup.py 2023-08-18 14:18:54 -07:00
Xuechen Li
bb4cded17b
support when num_heads is not divisible by world_size; resolves #459 (#461)
* uneql rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fit out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple of.

* make fn reuseable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
2023-08-18 14:10:35 -07:00
Tri Dao
ada4710d70 [ViT] Run black on vit.py 2023-08-17 17:45:09 -07:00
Tri Dao
a81900d4c1 [ViT] Minor fix so it runs 2023-08-17 17:25:34 -07:00
Tri Dao
4b661a569d [GPT] Run black on gpt.py 2023-08-16 23:47:50 -07:00
Tri Dao
bec5b3d374 [MHA] Run black on mha.py 2023-08-16 23:47:13 -07:00
Tri Dao
cb0daccc41 [FusedDense] Allow Row/ColumnParallelLinear to have uneven split 2023-08-16 23:43:35 -07:00
Tri Dao
bcfa7c9751 [FusedDense] Run black on fused_dense.py 2023-08-16 23:41:36 -07:00
Tri Dao
2286d7cea7 Bump to v2.0.8 2023-08-16 15:13:12 -07:00
Tri Dao
c65b5106ac Fix Bwd NaN for varlen when seqlen_q >> seqlen_k and causal 2023-08-16 15:12:36 -07:00