[Speculative Decoding] Medusa Implementation with Top-1 proposer (#4978)

Abhinav Goyal 2024-07-10 07:04:02 +05:30 committed by GitHub
parent d3a245138a
commit 2416b26e11
9 changed files with 587 additions and 4 deletions


@@ -0,0 +1,226 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various number of speculative tokens.
With those tests, we can say that, at a minimum, Medusa does not break the
correctness of the target model outputs.
"""
import pytest
from .conftest import run_greedy_equality_correctness_test
# main model
# lmsys/vicuna-7b-v1.3 was the original choice, but it causes OOM
# in the CI pipeline, so a smaller model is used instead.
MAIN_MODEL = "JackFram/llama-68m"
# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"
# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5
# precision
PRECISION = "float32"
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize("output_len", [
128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify greedy equality with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
128,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": k,
}
# Try a range of num. speculative tokens
for k in range(1, 1 + MAX_SPEC_TOKENS)
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Precision
"dtype": PRECISION,
# Main model
"model": MAIN_MODEL,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": SPEC_MODEL,
"num_speculative_tokens": MAX_SPEC_TOKENS,
"speculative_disable_by_batch_size": 4
}])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
if __name__ == "__main__":
import pytest
pytest.main([__file__])
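
The fixtures above drive the engine indirectly. For reference, here is a minimal standalone sketch (not part of this commit, and assuming the LLM constructor forwards these engine arguments unchanged) of enabling Medusa speculation with the same models and settings the tests exercise:

from vllm import LLM, SamplingParams

# Hypothetical usage sketch, not part of this commit: same main/draft models
# and engine flags as the tests above, driven directly.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-medusa-llama-68m-random",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    enforce_eager=True,
    dtype="float32",
)
# Greedy sampling (temperature=0) is what the greedy-equality tests rely on:
# with rejection sampling, the speculative output should match the
# non-speculative output token for token.
params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["Speculative decoding lets the target model"], params)
print(outputs[0].outputs[0].text)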


@@ -64,6 +64,7 @@ _GENERATION_MODELS = {
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
}


@@ -0,0 +1,159 @@
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.medusa import MedusaConfig
class ResidualBlock(nn.Module):
def __init__(self, hidden_size: int, num_layers: int) -> None:
super().__init__()
self.layers = nn.ModuleList([
nn.Linear(hidden_size, hidden_size, bias=False)
for _ in range(num_layers)
])
self.act = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
for layer in self.layers:
x = x + self.act(layer(x))
return x
class Medusa(nn.Module):
def __init__(self, config: MedusaConfig, **_) -> None:
super().__init__()
self.config = config
self.blocks = nn.ModuleList([
ResidualBlock(hidden_size=self.config.hidden_size,
num_layers=self.config.num_hidden_layers)
for _ in range(self.config.num_heads)
])
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_heads = nn.ModuleList([
ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) for _ in range(self.config.num_heads)
])
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)
self.token_map = None
def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
return [block(hidden_states) for block in self.blocks]
def compute_logits(
self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits = []
for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata)
if self.token_map is None:
logits.append(_logits)
else:
logits.append(-torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype))
logits[-1][..., self.token_map] = _logits
return logits
def sample(
self,
logits: List[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
logits = torch.stack(logits, dim=0).float()
logprobs = torch.log_softmax(logits, dim=-1)
token_ids = logits.argmax(-1) # support only top-1 for now
probs = torch.softmax(logits, dim=-1)
token_id_list = []
token_prob_list = []
token_logprob_list = []
for idx, seq_group in enumerate(sampling_metadata.seq_groups):
token_id_list.append(token_ids[:, seq_group.sample_indices])
token_prob_list.append(probs[:, seq_group.sample_indices])
token_logprob_list.append(logprobs[:, seq_group.sample_indices])
outputs: List[Optional[SamplerOutput]] = []
for idx in range(len(sampling_metadata.seq_groups)):
outputs.append(
SamplerOutput(
outputs=None,
sampled_token_probs=token_prob_list[idx].squeeze(1),
logprobs=token_logprob_list[idx].squeeze(1),
sampled_token_ids=token_id_list[idx].squeeze(1),
))
return outputs
def generate_proposals(
self,
previous_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> List[SamplerOutput]:
return self.sample(
logits=self.compute_logits(
hidden_states=self.forward(previous_hidden_states),
sampling_metadata=sampling_metadata,
),
sampling_metadata=sampling_metadata,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
weights_map = {}
for name, loaded_weight in weights:
name = name.replace("medusa_heads.", "")
if name == "token_map":
if self.truncated_vocab_size < self.orig_vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name in params_dict:
weights_map[name] = loaded_weight
for name, loaded_weight in weights_map.items():
if "lm_head" in name and self.token_map is not None and\
loaded_weight.shape[0] > self.token_map.shape[0]:
loaded_weight = loaded_weight[self.token_map]
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if self.token_map is not None:
self.token_map.to(device=self.lm_heads[0].weight.device)
assert (self.truncated_vocab_size
== self.orig_vocab_size) or (self.token_map is not None)
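
Putting the pieces above together: each of the num_heads blocks is a residual SiLU MLP applied to the target model's final hidden state, followed by that head's own LM head, and the argmax of each head's logits yields one proposed token per future position (the top-1 proposer in the commit title). A toy, self-contained shape walk-through, with plain nn.Linear standing in for ParallelLMHead and arbitrary sizes, purely for illustration:

import torch
import torch.nn as nn

# Toy stand-in for the Medusa heads above; the numbers are arbitrary.
hidden_size, vocab_size, num_heads, batch = 16, 32, 5, 3
blocks = nn.ModuleList(
    nn.Linear(hidden_size, hidden_size, bias=False) for _ in range(num_heads))
lm_heads = nn.ModuleList(
    nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_heads))

# One final-layer hidden state per sequence, supplied by the target model.
hidden_states = torch.randn(batch, hidden_size)

# Each head proposes one token: residual SiLU block -> LM head -> argmax.
proposals = []
for block, lm_head in zip(blocks, lm_heads):
    h = hidden_states + torch.nn.functional.silu(block(hidden_states))
    proposals.append(lm_head(h).argmax(dim=-1))        # shape: [batch]

draft_tokens = torch.stack(proposals, dim=1)           # shape: [batch, num_heads]
print(draft_tokens.shape)                              # torch.Size([3, 5])

The real module additionally stacks num_hidden_layers linear layers per block and can work over a truncated vocabulary that is mapped back to the full one through token_map.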


@@ -0,0 +1,127 @@
import weakref
from typing import List, Optional, Tuple
import torch
from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker
class MedusaWorker(NonLLMProposerWorkerBase, Worker):
"""Worker for Medusa.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Lazy initialization.
self._proposer: Top1Proposer
def init_device(self):
super().init_device()
self._proposer = Top1Proposer(
weakref.proxy(self), # type: ignore[arg-type]
self.device,
self.vocab_size,
max_proposal_len=self.max_model_len,
)
def set_include_gpu_probs_tensor(self):
pass
@torch.inference_mode()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,
sample_len: int,
) -> Tuple[List[SamplerOutput], bool]:
"""Run the model forward pass to generate sample_len future tokens.
Returns the list of sampler output, one per layer, along with indicator
of whether torch tensor in sampler output need to be transposed in
latter sampler_output_to_torch logic.
For medusa worker, this indicator shall be False.
"""
self._raise_if_unsupported(execute_model_req)
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
seq_lens, query_lens = self._prepare_input_tensors(
seq_group_metadata_list)
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_lens, query_lens, self.device,
self.model_runner.pin_memory)
model_outputs = self.model_runner.model.generate_proposals(
previous_hidden_states=execute_model_req.previous_hidden_states.
hidden_states,
sampling_metadata=sampling_metadata)
return model_outputs, False
def _prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[List[int], List[int]]:
if not seq_group_metadata_list:
return [], []
seq_lens: List[int] = []
query_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
is_prompt = seq_group_metadata.is_prompt
for seq_data in seq_group_metadata.seq_data.values():
seq_data_len = seq_data.get_len()
if is_prompt:
context_len = seq_data.get_num_computed_tokens()
seq_len = min(
seq_data_len,
context_len + seq_group_metadata.token_chunk_size)
seq_lens.append(seq_len)
query_lens.append(seq_len - context_len)
else:
seq_lens.append(seq_data_len)
query_lens.append(1)
return seq_lens, query_lens
def get_spec_proposals(
self,
execute_model_req: ExecuteModelRequest,
) -> SpeculativeProposals:
"""Produce speculations given an input batch of sequences. The number of
speculative tokens per sequence is determined by max_proposal_len.
"""
return self._proposer.get_spec_proposals(execute_model_req)
def _raise_if_unsupported(
self,
execute_model_req: ExecuteModelRequest,
) -> None:
"""MedusaWorker does not yet implement support for cache swap
operations or beam search.
"""
if any([
execute_model_req.blocks_to_swap_in,
execute_model_req.blocks_to_swap_out,
execute_model_req.blocks_to_copy
]):
raise NotImplementedError(
"MedusaWorker does not support cache operations")
if any(
len(seq_group_metadata.seq_data.keys()) != 1
for seq_group_metadata in
execute_model_req.seq_group_metadata_list):
raise NotImplementedError(
"MedusaWorker does not support beam search.")


@@ -18,6 +18,7 @@ from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.medusa_worker import MedusaWorker
from vllm.spec_decode.metrics import AsyncMetricsCollector
from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -129,6 +130,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
"model_config"].hf_config.model_type == "mlp_speculator":
disable_bonus_tokens = False
proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
elif draft_worker_kwargs[
"model_config"].hf_config.model_type == "medusa":
disable_bonus_tokens = False
proposer_worker = MedusaWorker(**draft_worker_kwargs)
else:
if draft_tp == 1:
draft_worker_kwargs[


@@ -6,8 +6,9 @@ from transformers import GenerationConfig, PretrainedConfig
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
JAISConfig, MLPSpeculatorConfig,
MPTConfig, RWConfig)
JAISConfig, MedusaConfig,
MLPSpeculatorConfig, MPTConfig,
RWConfig)
if VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@@ -24,6 +25,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
"jais": JAISConfig,
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
}
for name, cls in _CONFIG_REGISTRY.items():


@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
from vllm.transformers_utils.configs.mpt import MPTConfig
@@ -14,5 +15,6 @@ __all__ = [
"MPTConfig",
"RWConfig",
"JAISConfig",
"MedusaConfig",
"MLPSpeculatorConfig",
]


@@ -0,0 +1,60 @@
import os
from typing import Optional, Union
from transformers import PretrainedConfig
class MedusaConfig(PretrainedConfig):
model_type = "medusa"
def __init__(self,
hidden_size: int = 4096,
vocab_size: int = 32001,
num_heads: int = 5,
num_hidden_layers: int = 1,
max_paths: int = 64,
topk: int = 10,
truncated_vocab_size: Optional[int] = None,
**kwargs):
self.hidden_size = hidden_size
self.vocab_size = vocab_size
self.num_heads = num_heads
self.num_hidden_layers = num_hidden_layers
self.max_paths = max_paths
self.topk = topk
self.max_seq_len = int(2**20)
self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\
else truncated_vocab_size
if "architectures" not in kwargs:
kwargs["architectures"] = ["MedusaModel"]
super().__init__(**kwargs)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
) -> "MedusaConfig":
config_dict, kwargs = cls.get_config_dict(
pretrained_model_name_or_path, **kwargs)
for k in list(config_dict.keys()):
if 'num' in k:
if 'heads' in k:
config_dict["num_heads"] = config_dict.pop(k)
elif 'layers' in k:
config_dict["num_hidden_layers"] = config_dict.pop(k)
return cls.from_dict(config_dict, **kwargs)
@property
def num_attention_heads(self):
return 0
@property
def num_lookahead_tokens(self):
return self.num_heads
@num_lookahead_tokens.setter
def num_lookahead_tokens(self, num_lookahead_tokens: int):
self.num_heads = num_lookahead_tokens
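
The from_pretrained override above normalises whatever spelling a checkpoint uses for the head and layer counts onto num_heads / num_hidden_layers. A small sketch of that remapping with a hypothetical checkpoint dict (the key names, such as medusa_num_heads, are assumptions for illustration):

# Hypothetical checkpoint config; key names are assumptions for illustration.
config_dict = {"hidden_size": 4096, "vocab_size": 32000,
               "medusa_num_heads": 5, "medusa_num_layers": 1}
for k in list(config_dict.keys()):
    if "num" in k:
        if "heads" in k:
            config_dict["num_heads"] = config_dict.pop(k)
        elif "layers" in k:
            config_dict["num_hidden_layers"] = config_dict.pop(k)
assert config_dict["num_heads"] == 5
assert config_dict["num_hidden_layers"] == 1

The num_lookahead_tokens property then exposes num_heads to the speculative decoding machinery as the proposal length.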


@@ -78,8 +78,9 @@ class Worker(LocalOrDistributedWorkerBase):
speculative_args = {} if speculative_config is None \
or (speculative_config.draft_model_config.model ==
model_config.model) \
or (speculative_config.draft_model_config.hf_config.model_type !=
"mlp_speculator") else {"return_hidden_states": True}
or (speculative_config.draft_model_config.hf_config.model_type
not in ["medusa", "mlp_speculator"]) \
else {"return_hidden_states": True}
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
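
Restated as a plain predicate (a sketch, not the committed code), the multi-line conditional above passes return_hidden_states=True to the model runner only when a separate draft model of type medusa or mlp_speculator is configured:

def needs_hidden_states(speculative_config, model_config) -> bool:
    # Sketch of the condition above, not the committed implementation.
    if speculative_config is None:
        return False
    draft = speculative_config.draft_model_config
    if draft.model == model_config.model:
        return False
    return draft.hf_config.model_type in ("medusa", "mlp_speculator")

# speculative_args = ({"return_hidden_states": True}
#                     if needs_hidden_states(speculative_config, model_config)
#                     else {})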