[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (#6971)

William Lin authored 2024-08-08 22:42:45 -07:00, committed by GitHub
parent 99b4cf5f23
commit 57b7be0e1c
8 changed files with 52 additions and 3 deletions
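
In short: the Sampler previously reused its include_gpu_probs_tensor flag to decide whether greedy probabilities should also be rewritten in place (a behaviour speculative decoding relies on). This commit adds a separate should_modify_greedy_probs_inplace flag so the two behaviours can be toggled independently, and SpecDecodeWorker now sets both explicitly on the scorer and proposer. A minimal, self-contained sketch of the decoupled-flag pattern (a toy stand-in for illustration, not vLLM's actual Sampler):

    class TinySampler:
        """Toy sampler with two independent feature flags."""

        def __init__(self) -> None:
            # Return sampled probs/ids on GPU so a caller (e.g. a scorer) can reuse them.
            self.include_gpu_probs_tensor = False
            # Rewrite greedy rows of the prob tensor in place (one-hot at the argmax).
            self.should_modify_greedy_probs_inplace = False

        @property
        def _should_modify_greedy_probs(self) -> bool:
            # Before this change the sampler consulted include_gpu_probs_tensor here,
            # coupling the two behaviours; now each flag is read on its own.
            return self.should_modify_greedy_probs_inplace


    sampler = TinySampler()
    sampler.include_gpu_probs_tensor = True               # keep probs around for scoring
    assert sampler._should_modify_greedy_probs is False   # but leave them unmodified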


@@ -1,7 +1,7 @@
 import itertools
 import random
 from typing import Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch

 import pytest
 import torch
@@ -703,3 +703,28 @@ def test_sampler_repetition_penalty_mixed(device: str):

     assert tokens1[0] == tokens2[1]
     assert tokens1[1] == tokens2[0]
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_include_gpu_probs_tensor(device: str):
+    set_random_seed(42)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+    sampler.include_gpu_probs_tensor = True
+    sampler.should_modify_greedy_probs_inplace = False
+
+    sampling_params = SamplingParams(temperature=0)
+    mock_inplace = Mock()
+    with patch(
+            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
+            mock_inplace):
+        sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                    sampling_params, device)
+        mock_inplace.assert_not_called()
+
+    assert sampler_output.sampled_token_probs is not None
+    assert sampler_output.logprobs is not None
+    assert sampler_output.sampled_token_ids is not None


@@ -1067,6 +1067,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
     def include_gpu_probs_tensor(self):
         return self.base_layer.include_gpu_probs_tensor

+    @property
+    def should_modify_greedy_probs_inplace(self):
+        return self.base_layer.should_modify_greedy_probs_inplace
+
     def create_lora_weights(
         self,
         max_loras: int,


@@ -51,6 +51,7 @@ class Sampler(nn.Module):
         # containing the sampled token ids and probabilities. This is used by
         # speculative decoding.
         self.include_gpu_probs_tensor = False
+        self.should_modify_greedy_probs_inplace = False

     def _init_sampling_tensors(
         self,
@@ -177,8 +178,7 @@
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         """
-        # Modify greedy probs if include_gpu_probs_tensor is set.
-        return self.include_gpu_probs_tensor
+        return self.should_modify_greedy_probs_inplace

     def _get_bin_counts_and_mask(
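
For context, "encoding the sampling method into the probability distribution" means that for greedy requests the returned probabilities are overwritten with a one-hot distribution at the sampled token, which is the form the rejection sampler on the spec-decode side expects. A rough sketch of that in-place rewrite, assuming the usual one-hot formulation (illustrative only, not vLLM's _modify_greedy_probs_inplace verbatim):

    import torch

    def modify_greedy_probs_inplace(probs: torch.Tensor,
                                    sampled_token_ids: torch.Tensor) -> None:
        # probs: [num_greedy_seqs, vocab_size]; sampled_token_ids: [num_greedy_seqs]
        # After this, each row is a one-hot distribution at the greedily chosen token,
        # so P(sampled token) == 1.0 and every other token has probability 0.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_token_ids.unsqueeze(-1), 1.0)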


@@ -35,6 +35,9 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
     def set_include_gpu_probs_tensor(self):
         pass

+    def set_should_modify_greedy_probs_inplace(self):
+        pass
+
     @torch.inference_mode()
     def sampler_output(
         self,


@@ -46,6 +46,10 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True

+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+
     @torch.inference_mode()
     def sampler_output(
         self,


@@ -28,6 +28,10 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
         """Implementation optional"""
         pass

+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+

 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""


@@ -83,6 +83,12 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         # Need include_gpu_probs_tensor for multi_step_worker
         self._worker.set_include_gpu_probs_tensor()

+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
     def load_model(self) -> None:
         if self._is_dummy:
             return


@@ -295,7 +295,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
+        (self.scorer_worker.model_runner.model.sampler.
+         should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
+        self.proposer_worker.set_should_modify_greedy_probs_inplace()

     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.