Fix get_model
This commit is contained in:
parent
531e1c74e8
commit
ee9442518d
@ -8,6 +8,7 @@ MODEL_CLASSES = {
|
|||||||
|
|
||||||
|
|
||||||
def get_model(model_name: str) -> nn.Module:
|
def get_model(model_name: str) -> nn.Module:
|
||||||
if model_name not in MODEL_CLASSES:
|
for model_class, model in MODEL_CLASSES.items():
|
||||||
raise ValueError(f'Invalid model name: {model_name}')
|
if model_class in model_name:
|
||||||
return MODEL_CLASSES[model_name].from_pretrained(model_name)
|
return model.from_pretrained(model_name)
|
||||||
|
raise ValueError(f'Invalid model name: {model_name}')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user