diff --git a/cacheflow/worker/models/model_utils.py b/cacheflow/worker/models/model_utils.py index a98eac04..3878b87e 100644 --- a/cacheflow/worker/models/model_utils.py +++ b/cacheflow/worker/models/model_utils.py @@ -8,6 +8,7 @@ MODEL_CLASSES = { def get_model(model_name: str) -> nn.Module: - if model_name not in MODEL_CLASSES: - raise ValueError(f'Invalid model name: {model_name}') - return MODEL_CLASSES[model_name].from_pretrained(model_name) + for model_class, model in MODEL_CLASSES.items(): + if model_class in model_name: + return model.from_pretrained(model_name) + raise ValueError(f'Invalid model name: {model_name}')