[Speculative Decoding] Medusa Implementation with Top-1 proposer (#4978)
parent d3a245138a
commit 2416b26e11

tests/spec_decode/e2e/test_medusa_correctness.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens.

With those tests, we can say that, at the very least, Medusa does not break
the correctness of the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test

# main model
# lmsys/vicuna-7b-v1.3 was to be used, but it caused OOM in the CI
# pipeline, so a smaller model is used instead.
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-medusa-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 5

# precision
PRECISION = "float32"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
                                       test_llm_generator, batch_size: int,
                                       output_len: int):
    """Verify greedy equality with different batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_medusa_e2e_greedy_correctness_with_preemption(
        baseline_llm_generator, test_llm_generator, batch_size: int,
        output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int):
    """Verify that Medusa speculative decoding produces exact equality
    to the output without spec decode, for different values of
    num_speculative_tokens.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": SPEC_MODEL,
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
    """Verify that Medusa speculative decoding produces exact equality
    to the output without spec decode when speculation is disabled for
    large batch sizes.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


if __name__ == "__main__":
    pytest.main([__file__])
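For context, the test_llm_kwargs exercised above map onto the offline LLM entry point roughly as follows. This is a sketch rather than part of the commit: it assumes a vLLM build that includes this change, reuses the tiny CI models from the constants above, and the prompt is made up.

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",                  # target model (MAIN_MODEL)
    speculative_model="abhigoyal/vllm-medusa-llama-68m-random",  # SPEC_MODEL
    num_speculative_tokens=5,                    # <= num_heads of the speculator
    use_v2_block_manager=True,                   # required for spec decode
    enforce_eager=True,
    dtype="float32",
)

# Greedy sampling, mirroring the greedy-equality methodology of these tests.
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)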
@@ -64,6 +64,7 @@ _GENERATION_MODELS = {
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
+    "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM")
 }

vllm/model_executor/models/medusa.py (new file, 159 lines)
@@ -0,0 +1,159 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.medusa import MedusaConfig


class ResidualBlock(nn.Module):

    def __init__(self, hidden_size: int, num_layers: int) -> None:
        super().__init__()

        self.layers = nn.ModuleList([
            nn.Linear(hidden_size, hidden_size, bias=False)
            for _ in range(num_layers)
        ])
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = x + self.act(layer(x))
        return x


class Medusa(nn.Module):

    def __init__(self, config: MedusaConfig, **_) -> None:
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList([
            ResidualBlock(hidden_size=self.config.hidden_size,
                          num_layers=self.config.num_hidden_layers)
            for _ in range(self.config.num_heads)
        ])
        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_heads = nn.ModuleList([
            ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=self.truncated_vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            ) for _ in range(self.config.num_heads)
        ])

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
        return [block(hidden_states) for block in self.blocks]

    def compute_logits(
            self, hidden_states: List[torch.Tensor],
            sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
        logits = []

        for hs, lm_head in zip(hidden_states, self.lm_heads):
            _logits = self.logits_processor(lm_head, hs, sampling_metadata)

            if self.token_map is None:
                logits.append(_logits)
            else:
                logits.append(-torch.inf * torch.ones(
                    size=(*_logits.shape[:-1], self.orig_vocab_size),
                    device=_logits.device,
                    dtype=_logits.dtype))

                logits[-1][..., self.token_map] = _logits

        return logits

    def sample(
        self,
        logits: List[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        logits = torch.stack(logits, dim=0).float()
        logprobs = torch.log_softmax(logits, dim=-1)
        token_ids = logits.argmax(-1)  # support only top-1 for now
        probs = torch.softmax(logits, dim=-1)

        token_id_list = []
        token_prob_list = []
        token_logprob_list = []

        for idx, seq_group in enumerate(sampling_metadata.seq_groups):
            token_id_list.append(token_ids[:, seq_group.sample_indices])
            token_prob_list.append(probs[:, seq_group.sample_indices])
            token_logprob_list.append(logprobs[:, seq_group.sample_indices])

        outputs: List[Optional[SamplerOutput]] = []
        for idx in range(len(sampling_metadata.seq_groups)):
            outputs.append(
                SamplerOutput(
                    outputs=None,
                    sampled_token_probs=token_prob_list[idx].squeeze(1),
                    logprobs=token_logprob_list[idx].squeeze(1),
                    sampled_token_ids=token_id_list[idx].squeeze(1),
                ))

        return outputs

    def generate_proposals(
        self,
        previous_hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> List[SamplerOutput]:
        return self.sample(
            logits=self.compute_logits(
                hidden_states=self.forward(previous_hidden_states),
                sampling_metadata=sampling_metadata,
            ),
            sampling_metadata=sampling_metadata,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())

        weights_map = {}

        for name, loaded_weight in weights:
            name = name.replace("medusa_heads.", "")

            if name == "token_map":
                if self.truncated_vocab_size < self.orig_vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name in params_dict:
                weights_map[name] = loaded_weight

        for name, loaded_weight in weights_map.items():
            if "lm_head" in name and self.token_map is not None and\
                loaded_weight.shape[0] > self.token_map.shape[0]:

                loaded_weight = loaded_weight[self.token_map]

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if self.token_map is not None:
            self.token_map.to(device=self.lm_heads[0].weight.device)

        assert (self.truncated_vocab_size
                == self.orig_vocab_size) or (self.token_map is not None)
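The truncated-vocabulary path in compute_logits above can be illustrated in isolation: head logits over a truncated vocabulary are scattered into a full-vocabulary tensor pre-filled with -inf, so the top-1 proposal can only ever be a token present in token_map. A minimal self-contained sketch with made-up sizes (not vLLM code):

import torch

# Hypothetical sizes: full vocab of 10, speculator trained on 4 kept tokens.
orig_vocab_size = 10
token_map = torch.tensor([2, 3, 5, 7])           # kept token ids
_logits = torch.randn(3, token_map.numel())      # [num_tokens, truncated_vocab]

# Same scatter as compute_logits: everything outside token_map stays -inf.
full = -torch.inf * torch.ones(*_logits.shape[:-1], orig_vocab_size)
full[..., token_map] = _logits

# A top-1 proposal per position can only be one of the kept token ids.
assert all(t.item() in token_map.tolist() for t in full.argmax(-1))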
vllm/spec_decode/medusa_worker.py (new file, 127 lines)
@@ -0,0 +1,127 @@
import weakref
from typing import List, Optional, Tuple

import torch

from vllm.model_executor import SamplingMetadata
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
                           SequenceGroupMetadata)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker


class MedusaWorker(NonLLMProposerWorkerBase, Worker):
    """Worker for Medusa.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Lazy initialization.
        self._proposer: Top1Proposer

    def init_device(self):
        super().init_device()

        self._proposer = Top1Proposer(
            weakref.proxy(self),  # type: ignore[arg-type]
            self.device,
            self.vocab_size,
            max_proposal_len=self.max_model_len,
        )

    def set_include_gpu_probs_tensor(self):
        pass

    @torch.inference_mode()
    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
    ) -> Tuple[List[SamplerOutput], bool]:
        """Run the model forward pass to generate sample_len future tokens.
        Returns the list of sampler output, one per head, along with an
        indicator of whether the torch tensors in the sampler output need to
        be transposed in the later sampler_output_to_torch logic.

        For the Medusa worker, this indicator is always False.
        """
        self._raise_if_unsupported(execute_model_req)

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list

        seq_lens, query_lens = self._prepare_input_tensors(
            seq_group_metadata_list)

        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list, seq_lens, query_lens, self.device,
            self.model_runner.pin_memory)

        model_outputs = self.model_runner.model.generate_proposals(
            previous_hidden_states=execute_model_req.previous_hidden_states.
            hidden_states,
            sampling_metadata=sampling_metadata)

        return model_outputs, False

    def _prepare_input_tensors(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    ) -> Tuple[List[int], List[int]]:
        if not seq_group_metadata_list:
            return [], []

        seq_lens: List[int] = []
        query_lens: List[int] = []

        for seq_group_metadata in seq_group_metadata_list:
            is_prompt = seq_group_metadata.is_prompt

            for seq_data in seq_group_metadata.seq_data.values():
                seq_data_len = seq_data.get_len()
                if is_prompt:
                    context_len = seq_data.get_num_computed_tokens()
                    seq_len = min(
                        seq_data_len,
                        context_len + seq_group_metadata.token_chunk_size)
                    seq_lens.append(seq_len)
                    query_lens.append(seq_len - context_len)
                else:
                    seq_lens.append(seq_data_len)
                    query_lens.append(1)

        return seq_lens, query_lens

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """

        return self._proposer.get_spec_proposals(execute_model_req)

    def _raise_if_unsupported(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> None:
        """MedusaWorker does not yet implement support for cache swap
        operations or beam search.
        """
        if any([
                execute_model_req.blocks_to_swap_in,
                execute_model_req.blocks_to_swap_out,
                execute_model_req.blocks_to_copy
        ]):
            raise NotImplementedError(
                "MedusaWorker does not support cache operations")

        if any(
                len(seq_group_metadata.seq_data.keys()) != 1
                for seq_group_metadata in
                execute_model_req.seq_group_metadata_list):
            raise NotImplementedError(
                "MedusaWorker does not support beam search.")
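As a worked example of the seq_lens/query_lens bookkeeping in _prepare_input_tensors, the standalone sketch below reproduces the same arithmetic; lens_for and its arguments are hypothetical stand-ins for the SequenceGroupMetadata/SequenceData fields used above, not vLLM code.

def lens_for(is_prompt: bool, total_len: int, computed: int,
             token_chunk_size: int):
    if is_prompt:
        # Chunked prefill: attend over everything computed so far plus this
        # chunk, and query only the new chunk.
        seq_len = min(total_len, computed + token_chunk_size)
        return seq_len, seq_len - computed
    # Decode: attend over the whole sequence, query a single new token.
    return total_len, 1

# Prompt of 100 tokens, 64 already computed, chunk of 16 -> (80, 16).
assert lens_for(True, 100, 64, 16) == (80, 16)
# Decode step on a 120-token sequence -> (120, 1).
assert lens_for(False, 120, 0, 0) == (120, 1)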
@@ -18,6 +18,7 @@ from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
 from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
+from vllm.spec_decode.medusa_worker import MedusaWorker
 from vllm.spec_decode.metrics import AsyncMetricsCollector
 from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker
 from vllm.spec_decode.multi_step_worker import MultiStepWorker
@@ -129,6 +130,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                 "model_config"].hf_config.model_type == "mlp_speculator":
             disable_bonus_tokens = False
             proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs)
+        elif draft_worker_kwargs[
+                "model_config"].hf_config.model_type == "medusa":
+            disable_bonus_tokens = False
+            proposer_worker = MedusaWorker(**draft_worker_kwargs)
         else:
             if draft_tp == 1:
                 draft_worker_kwargs[
@@ -6,8 +6,9 @@ from transformers import GenerationConfig, PretrainedConfig
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                             JAISConfig, MLPSpeculatorConfig,
-                                             MPTConfig, RWConfig)
+                                             JAISConfig, MedusaConfig,
+                                             MLPSpeculatorConfig, MPTConfig,
+                                             RWConfig)

 if VLLM_USE_MODELSCOPE:
     from modelscope import AutoConfig
@@ -24,6 +25,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
     "jais": JAISConfig,
     "mlp_speculator": MLPSpeculatorConfig,
+    "medusa": MedusaConfig,
 }

 for name, cls in _CONFIG_REGISTRY.items():
@@ -5,6 +5,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
+from vllm.transformers_utils.configs.medusa import MedusaConfig
 from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig

@@ -14,5 +15,6 @@ __all__ = [
     "MPTConfig",
     "RWConfig",
     "JAISConfig",
+    "MedusaConfig",
     "MLPSpeculatorConfig",
 ]

vllm/transformers_utils/configs/medusa.py (new file, 60 lines)
@@ -0,0 +1,60 @@
import os
from typing import Optional, Union

from transformers import PretrainedConfig


class MedusaConfig(PretrainedConfig):
    model_type = "medusa"

    def __init__(self,
                 hidden_size: int = 4096,
                 vocab_size: int = 32001,
                 num_heads: int = 5,
                 num_hidden_layers: int = 1,
                 max_paths: int = 64,
                 topk: int = 10,
                 truncated_vocab_size: Optional[int] = None,
                 **kwargs):

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = int(2**20)
        self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\
            else truncated_vocab_size
        if "architectures" not in kwargs:
            kwargs["architectures"] = ["MedusaModel"]

        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "MedusaConfig":
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs)
        for k in list(config_dict.keys()):
            if 'num' in k:
                if 'heads' in k:
                    config_dict["num_heads"] = config_dict.pop(k)
                elif 'layers' in k:
                    config_dict["num_hidden_layers"] = config_dict.pop(k)
        return cls.from_dict(config_dict, **kwargs)

    @property
    def num_attention_heads(self):
        return 0

    @property
    def num_lookahead_tokens(self):
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens
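A quick usage sketch of the config class above (values are illustrative, and the import path is the one added by this commit): truncated_vocab_size defaults to the full vocab_size, and num_lookahead_tokens is a read/write proxy for num_heads, which is how num_speculative_tokens maps onto the number of Medusa heads.

from vllm.transformers_utils.configs.medusa import MedusaConfig

cfg = MedusaConfig(hidden_size=768, vocab_size=32000, num_heads=5,
                   num_hidden_layers=1)
assert cfg.truncated_vocab_size == 32000   # defaults to vocab_size
assert cfg.num_lookahead_tokens == 5       # proxied to num_heads

cfg.num_lookahead_tokens = 3               # setter writes through to num_heads
assert cfg.num_heads == 3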
@@ -78,8 +78,9 @@ class Worker(LocalOrDistributedWorkerBase):
         speculative_args = {} if speculative_config is None \
             or (speculative_config.draft_model_config.model ==
                 model_config.model) \
-            or (speculative_config.draft_model_config.hf_config.model_type !=
-                "mlp_speculator") else {"return_hidden_states": True}
+            or (speculative_config.draft_model_config.hf_config.model_type
+                not in ["medusa", "mlp_speculator"]) \
+            else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
         if model_runner_cls is not None:
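The chained backslash condition above decides when the target-model runner is asked to return hidden states for the proposer. A readable restatement as a sketch (the helper name is hypothetical, not vLLM code):

def needs_hidden_states(speculative_config, model_config) -> bool:
    if speculative_config is None:
        return False
    draft = speculative_config.draft_model_config
    if draft.model == model_config.model:
        # Self-speculation with the same model does not need extra state.
        return False
    # Medusa and MLPSpeculator both condition their heads on the target
    # model's last hidden states.
    return draft.hf_config.model_type in ["medusa", "mlp_speculator"]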