[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)
This commit is contained in:
parent
99b4cf5f23
commit
57b7be0e1c
@ -1,7 +1,7 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
|
||||
|
||||
assert tokens1[0] == tokens2[1]
|
||||
assert tokens1[1] == tokens2[0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_include_gpu_probs_tensor(device: str):
|
||||
set_random_seed(42)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
_, fake_logits, sampler = _prepare_test(batch_size)
|
||||
sampler.include_gpu_probs_tensor = True
|
||||
sampler.should_modify_greedy_probs_inplace = False
|
||||
|
||||
sampling_params = SamplingParams(temperature=0)
|
||||
|
||||
mock_inplace = Mock()
|
||||
with patch(
|
||||
"vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
|
||||
mock_inplace):
|
||||
|
||||
sampler_output = _do_sample(batch_size, fake_logits, sampler,
|
||||
sampling_params, device)
|
||||
mock_inplace.assert_not_called()
|
||||
|
||||
assert sampler_output.sampled_token_probs is not None
|
||||
assert sampler_output.logprobs is not None
|
||||
assert sampler_output.sampled_token_ids is not None
|
||||
|
||||
@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
def include_gpu_probs_tensor(self):
|
||||
return self.base_layer.include_gpu_probs_tensor
|
||||
|
||||
@property
|
||||
def should_modify_greedy_probs_inplace(self):
|
||||
return self.base_layer.should_modify_greedy_probs_inplace
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
|
||||
@ -51,6 +51,7 @@ class Sampler(nn.Module):
|
||||
# containing the sampled token ids and probabilities. This is used by
|
||||
# speculative decoding.
|
||||
self.include_gpu_probs_tensor = False
|
||||
self.should_modify_greedy_probs_inplace = False
|
||||
|
||||
def _init_sampling_tensors(
|
||||
self,
|
||||
@ -177,8 +178,7 @@ class Sampler(nn.Module):
|
||||
This is used by speculative decoding, which requires that the sampling
|
||||
method be encoded into the probability distribution.
|
||||
"""
|
||||
# Modify greedy probs if include_gpu_probs_tensor is set.
|
||||
return self.include_gpu_probs_tensor
|
||||
return self.should_modify_greedy_probs_inplace
|
||||
|
||||
|
||||
def _get_bin_counts_and_mask(
|
||||
|
||||
@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
|
||||
def set_include_gpu_probs_tensor(self):
|
||||
pass
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self):
|
||||
pass
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
|
||||
@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
|
||||
# Need include_gpu_probs_tensor for MultiStepWorker
|
||||
self.model_runner.model.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
|
||||
True)
|
||||
|
||||
@torch.inference_mode()
|
||||
def sampler_output(
|
||||
self,
|
||||
|
||||
@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
|
||||
"""Implementation optional"""
|
||||
pass
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
"""Implementation optional"""
|
||||
pass
|
||||
|
||||
|
||||
class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
|
||||
"""Proposer worker which does not use a model with kvcache"""
|
||||
|
||||
@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
|
||||
# Need include_gpu_probs_tensor for multi_step_worker
|
||||
self._worker.set_include_gpu_probs_tensor()
|
||||
|
||||
def set_should_modify_greedy_probs_inplace(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
self._worker.set_should_modify_greedy_probs_inplace()
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self._is_dummy:
|
||||
return
|
||||
|
||||
@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
|
||||
"""
|
||||
(self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
|
||||
) = True
|
||||
(self.scorer_worker.model_runner.model.sampler.
|
||||
should_modify_greedy_probs_inplace) = True
|
||||
self.proposer_worker.set_include_gpu_probs_tensor()
|
||||
self.proposer_worker.set_should_modify_greedy_probs_inplace()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of cache blocks to use.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user