[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)

William Lin authored 2024-08-08 22:42:45 -07:00, committed by GitHub
parent 99b4cf5f23
commit 57b7be0e1c
8 changed files with 52 additions and 3 deletions
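
In short: before this change, `Sampler._should_modify_greedy_probs_inplace` simply returned `include_gpu_probs_tensor`, so any worker that requested the GPU probs tensor also had greedy probabilities rewritten in place. This commit introduces a dedicated `should_modify_greedy_probs_inplace` flag and threads a setter for it through the speculative-decoding workers. A minimal before/after sketch of the decoupling (illustrative classes, not the vLLM source):

class SamplerBefore:

    def __init__(self):
        self.include_gpu_probs_tensor = False

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        # Before: one flag drove both behaviors.
        return self.include_gpu_probs_tensor


class SamplerAfter:

    def __init__(self):
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False

    @property
    def _should_modify_greedy_probs_inplace(self) -> bool:
        # After: each behavior is opted into independently.
        return self.should_modify_greedy_probs_inplace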

tests/samplers/test_sampler.py

@@ -1,7 +1,7 @@
 import itertools
 import random
 from typing import Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):
 
     assert tokens1[0] == tokens2[1]
     assert tokens1[1] == tokens2[0]
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_include_gpu_probs_tensor(device: str):
+    set_random_seed(42)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+    sampler.include_gpu_probs_tensor = True
+    sampler.should_modify_greedy_probs_inplace = False
+
+    sampling_params = SamplingParams(temperature=0)
+    mock_inplace = Mock()
+
+    with patch(
+            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
+            mock_inplace):
+        sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                    sampling_params, device)
+
+    mock_inplace.assert_not_called()
+
+    assert sampler_output.sampled_token_probs is not None
+    assert sampler_output.logprobs is not None
+    assert sampler_output.sampled_token_ids is not None
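
The test verifies a negative: with `should_modify_greedy_probs_inplace` left False, the in-place helper must never run even though `include_gpu_probs_tensor` is True. The technique is to patch the module-level function with a `Mock` and assert it was never called. A self-contained sketch of that pattern, using `os.getcwd` purely as an arbitrary patch target:

import os
from unittest.mock import Mock, patch


def workload() -> int:
    # Stand-in for the code under test; it never calls os.getcwd().
    return sum(range(10))


def test_getcwd_not_called():
    mock_fn = Mock()
    with patch("os.getcwd", mock_fn):
        result = workload()
    assert result == 45
    mock_fn.assert_not_called()  # fails only if the patched function ran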

vllm/lora/layers.py

@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     def include_gpu_probs_tensor(self):
         return self.base_layer.include_gpu_probs_tensor
 
+    @property
+    def should_modify_greedy_probs_inplace(self):
+        return self.base_layer.should_modify_greedy_probs_inplace
+
     def create_lora_weights(
         self,
         max_loras: int,

vllm/model_executor/layers/sampler.py

@@ -51,6 +51,7 @@ class Sampler(nn.Module):
         # containing the sampled token ids and probabilities. This is used by
         # speculative decoding.
         self.include_gpu_probs_tensor = False
+        self.should_modify_greedy_probs_inplace = False
 
     def _init_sampling_tensors(
         self,
@@ -177,8 +178,7 @@
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         """
-        # Modify greedy probs if include_gpu_probs_tensor is set.
-        return self.include_gpu_probs_tensor
+        return self.should_modify_greedy_probs_inplace
 
 
 def _get_bin_counts_and_mask(
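
For context on the property changed above: when in-place modification is enabled, the sampler rewrites each greedy sequence's probability row as a one-hot distribution over the token it actually picked, so the probs tensor consumed by speculative decoding encodes the greedy decision. A simplified sketch of that operation (not vLLM's `_modify_greedy_probs_inplace` itself):

import torch


def one_hot_greedy_rows(probs: torch.Tensor,
                        sampled_ids: torch.Tensor) -> None:
    # probs: [num_seqs, vocab_size]; sampled_ids: [num_seqs], dtype int64.
    probs.zero_()  # wipe each distribution in place...
    # ...then put all probability mass on the greedily sampled token.
    probs.scatter_(dim=1, index=sampled_ids.unsqueeze(1), value=1.0)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
sampled = probs.argmax(dim=-1)  # the greedy (temperature=0) choice per row
one_hot_greedy_rows(probs, sampled)
assert torch.equal(probs.argmax(dim=-1), sampled)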

vllm/spec_decode/medusa_worker.py

@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
     def set_include_gpu_probs_tensor(self):
         pass
 
+    def set_should_modify_greedy_probs_inplace(self):
+        pass
+
     @torch.inference_mode()
     def sampler_output(
         self,

vllm/spec_decode/multi_step_worker.py

@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+
     @torch.inference_mode()
     def sampler_output(
         self,

vllm/spec_decode/proposer_worker_base.py

@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         """Implementation optional"""
         pass
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+
 
 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""

vllm/spec_decode/smaller_tp_proposer_worker.py

@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         # Need include_gpu_probs_tensor for multi_step_worker
         self._worker.set_include_gpu_probs_tensor()
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
     def load_model(self) -> None:
         if self._is_dummy:
             return

vllm/spec_decode/spec_decode_worker.py

@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
+        (self.scorer_worker.model_runner.model.sampler.
+         should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
+        self.proposer_worker.set_should_modify_greedy_probs_inplace()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.