Fix get_model
This commit is contained in:
parent
531e1c74e8
commit
ee9442518d
@ -8,6 +8,7 @@ MODEL_CLASSES = {
|
||||
|
||||
|
||||
def get_model(model_name: str) -> nn.Module:
|
||||
if model_name not in MODEL_CLASSES:
|
||||
raise ValueError(f'Invalid model name: {model_name}')
|
||||
return MODEL_CLASSES[model_name].from_pretrained(model_name)
|
||||
for model_class, model in MODEL_CLASSES.items():
|
||||
if model_class in model_name:
|
||||
return model.from_pretrained(model_name)
|
||||
raise ValueError(f'Invalid model name: {model_name}')
|
||||
|
||||
Loading…
Reference in New Issue
Block a user