From fffa2e1f4b7534d5f86e900838d9a24dfba307c9 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 13 Feb 2023 09:36:12 +0000
Subject: [PATCH] Add model_utils

---
Review note: get_model() now selects the model class by substring match on
the family keyword instead of exact equality. The model_utils.py blob id on
the index line below predates this revision.

 cacheflow/worker/models/__init__.py    |  6 ++++--
 cacheflow/worker/models/model_utils.py | 24 ++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 2 deletions(-)
 create mode 100644 cacheflow/worker/models/model_utils.py

diff --git a/cacheflow/worker/models/__init__.py b/cacheflow/worker/models/__init__.py
index 4736c477..f0c68e5b 100644
--- a/cacheflow/worker/models/__init__.py
+++ b/cacheflow/worker/models/__init__.py
@@ -1,5 +1,7 @@
-from cacheflow.worker.models.opt import OPTForCausalLM
+from cacheflow.worker.models.model_utils import get_model
+
 
 __all__ = [
-    'OPTForCausalLM',
+    'get_model',
+
 ]
diff --git a/cacheflow/worker/models/model_utils.py b/cacheflow/worker/models/model_utils.py
new file mode 100644
index 00000000..a98eac04
--- /dev/null
+++ b/cacheflow/worker/models/model_utils.py
@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+from cacheflow.worker.models.opt import OPTForCausalLM
+
+# Maps a model-family keyword to the class implementing that family.
+MODEL_CLASSES = {
+    'opt': OPTForCausalLM,
+}
+
+
+def get_model(model_name: str) -> nn.Module:
+    """Load the model whose family keyword occurs in ``model_name``.
+
+    Matching is by substring, so a full checkpoint name such as
+    ``facebook/opt-125m`` selects the ``'opt'`` entry. The unmodified
+    ``model_name`` is passed to that class's ``from_pretrained``.
+
+    Raises:
+        ValueError: If no key of ``MODEL_CLASSES`` occurs in ``model_name``.
+    """
+    for family, model_class in MODEL_CLASSES.items():
+        if family in model_name:
+            return model_class.from_pretrained(model_name)
+    raise ValueError(f'Invalid model name: {model_name}')