Add model_utils
This commit is contained in:
parent
bb59a3e730
commit
fffa2e1f4b
@ -1,5 +1,7 @@
|
|||||||
from cacheflow.worker.models.opt import OPTForCausalLM
|
from cacheflow.worker.models.model_utils import get_model
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'OPTForCausalLM',
|
'get_model',
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|||||||
13
cacheflow/worker/models/model_utils.py
Normal file
13
cacheflow/worker/models/model_utils.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from cacheflow.worker.models.opt import OPTForCausalLM
|
||||||
|
|
||||||
|
MODEL_CLASSES = {
|
||||||
|
'opt': OPTForCausalLM,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(model_name: str) -> nn.Module:
|
||||||
|
if model_name not in MODEL_CLASSES:
|
||||||
|
raise ValueError(f'Invalid model name: {model_name}')
|
||||||
|
return MODEL_CLASSES[model_name].from_pretrained(model_name)
|
||||||
Loading…
Reference in New Issue
Block a user