diff --git a/cacheflow/models/model_utils.py b/cacheflow/models/model_utils.py index 3878b87e..522630d7 100644 --- a/cacheflow/models/model_utils.py +++ b/cacheflow/models/model_utils.py @@ -1,6 +1,6 @@ import torch.nn as nn -from cacheflow.worker.models.opt import OPTForCausalLM +from cacheflow.models.opt import OPTForCausalLM MODEL_CLASSES = { 'opt': OPTForCausalLM,