diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index fde44265..6aa43f22 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -4,13 +4,13 @@ from typing import Iterable, List, Tuple import torch import torch.nn as nn +from vllm.config import VllmConfig from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.transformers_utils.configs import MLPSpeculatorConfig SQRT2 = 2**0.5 @@ -65,8 +65,9 @@ class MLPSpeculator(nn.Module): https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite """ - def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() + config = vllm_config.model_config.hf_config self.n_predict = config.n_predict self.vocab_size = config.vocab_size self.emb_dim = config.emb_dim