Add model_utils
This commit is contained in:
parent
bb59a3e730
commit
fffa2e1f4b
@ -1,5 +1,7 @@
|
||||
from cacheflow.worker.models.opt import OPTForCausalLM
|
||||
from cacheflow.worker.models.model_utils import get_model
|
||||
|
||||
|
||||
__all__ = [
|
||||
'OPTForCausalLM',
|
||||
'get_model',
|
||||
|
||||
]
|
||||
|
||||
13
cacheflow/worker/models/model_utils.py
Normal file
13
cacheflow/worker/models/model_utils.py
Normal file
@ -0,0 +1,13 @@
|
||||
import torch.nn as nn
|
||||
|
||||
from cacheflow.worker.models.opt import OPTForCausalLM
|
||||
|
||||
MODEL_CLASSES = {
|
||||
'opt': OPTForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
def get_model(model_name: str) -> nn.Module:
|
||||
if model_name not in MODEL_CLASSES:
|
||||
raise ValueError(f'Invalid model name: {model_name}')
|
||||
return MODEL_CLASSES[model_name].from_pretrained(model_name)
|
||||
Loading…
Reference in New Issue
Block a user