Commit Graph

9 Commits

Author | SHA1 | Message | Date
Tri Dao | e45a46a5b7 | [Rotary] Implement GPT-J style (interleaved) rotary | 2023-03-14 14:35:53 -07:00
Tri Dao | 85b8e3d334 | [Docs] Mention that XPos's scale_base is recommended to be 512 | 2022-12-29 20:25:02 -08:00
Tri Dao | 1e712ea8b0 | Implement TensorParallel for MHA | 2022-12-25 11:39:55 -08:00
Tri Dao | 496e4f528c | Implement XPos (Sun et al.) | 2022-12-21 14:17:58 -08:00
Alexander Ploshkin | ee8984d2be | add asserts for sin shape | 2022-12-17 13:34:57 +04:00
Alexander Ploshkin | c7c66976cc | fix slicing dimensions | 2022-12-16 15:39:06 +04:00
Alexander Ploshkin | 96656b9323 | Remove redundant shape asserts in rotary embeddings | 2022-12-15 18:13:21 +04:00
Tri Dao | 71f674ae23 | [Rotary] Customize base, support seqlen_offset | 2022-11-17 11:43:36 -08:00
Tri Dao | d4b320b31f | Add MLP, MHA, Block, Embedding modules | 2022-11-13 22:06:44 -08:00