[BugFix] Fix MLPSpeculator handling of num_speculative_tokens (#5876)

This commit is contained in:
Nick Hill 2024-06-27 10:59:33 -07:00 committed by GitHub
parent 3fd02bda51
commit 691e29ecf3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 18 additions and 10 deletions

View File

@ -920,15 +920,19 @@ class SpeculativeConfig:
max_logprobs=target_model_config.max_logprobs,
)
if (draft_model_config.hf_config.model_type == "mlp_speculator"
draft_hf_config = draft_model_config.hf_config
if (draft_hf_config.model_type == "mlp_speculator"
and target_parallel_config.world_size != 1):
# MLPSpeculator TP support will be added very soon
raise ValueError(
"Speculative decoding with mlp_speculator models does not "
"yet support distributed inferencing (TP > 1).")
n_predict = getattr(draft_model_config.hf_config, "n_predict",
None)
if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import MLPSpeculatorConfig
class MLPSpeculatorLayerNorm(nn.Module):
@ -48,7 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
class MLPSpeculator(nn.Module):
def __init__(self, config, **kwargs) -> None:
def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None:
super().__init__()
self.n_predict = config.n_predict
self.vocab_size = config.vocab_size
@ -56,8 +57,7 @@ class MLPSpeculator(nn.Module):
self.inner_dim = config.inner_dim if config.inner_dim != 0 \
else config.emb_dim
self.max_speculative_tokens = getattr(config, "max_speculative_tokens",
self.n_predict)
self.max_speculative_tokens = config.num_lookahead_tokens
self.emb = nn.ModuleList([
VocabParallelEmbedding(config.vocab_size,
@ -137,7 +137,8 @@ class MLPSpeculator(nn.Module):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
param = params_dict[name.replace("speculator.", "")]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
param = params_dict.get(name.replace("speculator.", ""))
if param is not None:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)

View File

@ -35,6 +35,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
candidate tree.
For each candidate branch in the tree, head n produces topk[n]
additional sub-branches.
NOTE: This parameter is currently unused.
n_candidates: int
number of child candidates to create per sequence
"""
@ -47,4 +48,6 @@ class MLPSpeculatorConfig(PretrainedConfig):
self.n_predict = n_predict
self.top_k_tokens_per_head = top_k_tokens_per_head
self.n_candidates = n_candidates
self.num_lookahead_tokens = n_predict
super().__init__(**kwargs)