Commit Graph

13 Commits

| Author | SHA1 | Message | Date |
|---|---|---|---|
| Tri Dao | dfe29f5e2b | [Gen] Don't use ft_attention, use flash_attn_with_kvcache instead | 2023-09-18 15:29:06 -07:00 |
| Tri Dao | d0032700d1 | Add tests for Pythia, GPT-JT, and RedPajama models | 2023-09-13 01:10:39 -07:00 |
| Tri Dao | 8a733cbd53 | [Gen] Fix calling update_graph_cache in tests | 2023-09-10 17:22:37 -07:00 |
| Tri Dao | 913922cac5 | [Gen] Refactor decoding function | 2023-09-04 17:01:38 -07:00 |
| Tri Dao | 0e8c46ae08 | Run isort and black on test files | 2023-08-18 20:59:35 -07:00 |
| Tri Dao | 8e9820a55b | [Rotary] Fix tests when loading state dict with rotary inv_freqs | 2023-07-26 07:16:33 -10:00 |
| Tri Dao | 425dbcb6c6 | [MHA] Implement MQA/GQA | 2023-07-23 00:06:58 -07:00 |
| Tri Dao | b3177dfaf6 | [GPT] Enable FlashAttention for GPT-J | 2023-07-21 17:29:10 -07:00 |
| Tri Dao | 96d10f6545 | Implement LLaMa | 2023-04-18 21:51:35 -07:00 |
| Tri Dao | 605655bc66 | [Gen] Fix FT kernel when using CG | 2023-04-14 16:50:01 -07:00 |
| Tri Dao | 393882bc08 | [LayerNorm] Implement LN with parallel residual, support dim 8k | 2023-03-31 14:23:45 -07:00 |
| Tri Dao | 993d12448e | Implement GPT-NeoX | 2023-03-29 01:21:25 -07:00 |
| Tri Dao | 4d87e4d875 | Implement GPT-J | 2023-03-22 16:16:58 -07:00 |