diff --git a/cacheflow/worker/models/__init__.py b/cacheflow/worker/models/__init__.py
index 4736c477..f0c68e5b 100644
--- a/cacheflow/worker/models/__init__.py
+++ b/cacheflow/worker/models/__init__.py
@@ -1,5 +1,7 @@
-from cacheflow.worker.models.opt import OPTForCausalLM
+from cacheflow.worker.models.model_utils import get_model
+
 
 __all__ = [
-    'OPTForCausalLM',
+    'get_model',
+
 ]
diff --git a/cacheflow/worker/models/model_utils.py b/cacheflow/worker/models/model_utils.py
new file mode 100644
index 00000000..a98eac04
--- /dev/null
+++ b/cacheflow/worker/models/model_utils.py
@@ -0,0 +1,25 @@
+import torch.nn as nn
+
+from cacheflow.worker.models.opt import OPTForCausalLM
+
+# Maps a model-family substring to the class that implements that family.
+MODEL_CLASSES = {
+    'opt': OPTForCausalLM,
+}
+
+
+def get_model(model_name: str) -> nn.Module:
+    """Instantiate the model class registered for ``model_name``.
+
+    The first ``MODEL_CLASSES`` key that occurs as a substring of
+    ``model_name`` selects the class, so both a bare family name ('opt')
+    and a longer identifier containing it (e.g. 'facebook/opt-125m') are
+    accepted. ``model_name`` is forwarded unchanged to ``from_pretrained``.
+
+    Raises:
+        ValueError: If no registered key occurs in ``model_name``.
+    """
+    for model_key, model_class in MODEL_CLASSES.items():
+        if model_key in model_name:
+            return model_class.from_pretrained(model_name)
+    raise ValueError(f'Invalid model name: {model_name}')