[Misc][Speculative decoding] Typos and typing fixes (#6467)
Co-authored-by: caishangming.csm <caishangming.csm@alibaba-inc.com>
commit a19e8d3726
parent 10383887e0
@@ -43,7 +43,7 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
         )

     def set_include_gpu_probs_tensor(self) -> None:
-        # Need include_gpu_probs_tensor for multi_step_worker
+        # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True

     @torch.inference_mode()
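For context on the comment fixed above: the draft worker's sampled probabilities feed the target model's rejection sampler, so the sampler must keep them as a GPU tensor rather than returning token ids alone. A minimal sketch of the standard speculative-decoding acceptance rule (hypothetical helper, not vLLM's code):

import torch

def accept_draft_tokens(draft_probs: torch.Tensor,
                        target_probs: torch.Tensor,
                        draft_token_ids: torch.Tensor) -> torch.Tensor:
    """Boolean mask of accepted draft tokens (illustrative only).

    draft_probs / target_probs: [num_tokens, vocab_size], same device.
    draft_token_ids: [num_tokens] token ids proposed by the draft model.
    """
    # Probability each model assigned to the proposed token.
    p_draft = draft_probs.gather(1, draft_token_ids.unsqueeze(1)).squeeze(1)
    p_target = target_probs.gather(1, draft_token_ids.unsqueeze(1)).squeeze(1)
    # Accept each token with probability min(1, p_target / p_draft); this
    # elementwise comparison is why both probability tensors need to stay
    # on the GPU.
    accept_prob = (p_target / p_draft).clamp(max=1.0)
    return torch.rand_like(accept_prob) < accept_prob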
@@ -13,7 +13,7 @@ from vllm.worker.worker_base import LoraNotSupportedWorkerBase
 class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
     """NGramWorker provides a light drafter without need for model.

-    Current NGramWorker only implement prompt lookup decoding,
+    Current NGramWorker only implements prompt lookup decoding,
     and in future we may also do RAG type drafter and other scenarios
     which don't rely on LLM model to give proposals.
     """
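For readers unfamiliar with the prompt lookup decoding mentioned in the docstring: the draft is produced by matching the most recent n-gram of the sequence against earlier context and proposing the tokens that followed it. A rough sketch of the idea (hypothetical function, not NGramWorker's implementation):

from typing import List, Optional

def propose_by_prompt_lookup(token_ids: List[int],
                             ngram_size: int,
                             num_speculative_tokens: int) -> Optional[List[int]]:
    """Return up to num_speculative_tokens draft tokens, or None if no match."""
    if len(token_ids) < ngram_size:
        return None
    tail = token_ids[-ngram_size:]
    # Scan backwards from the most recent earlier occurrence of the tail.
    for start in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[start:start + ngram_size] == tail:
            follow = token_ids[start + ngram_size:
                               start + ngram_size + num_speculative_tokens]
            return follow or None
    return None

# Example: the n-gram (7, 8) occurred earlier, so (9, 10) is proposed.
assert propose_by_prompt_lookup([7, 8, 9, 10, 5, 7, 8], 2, 2) == [9, 10]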
@@ -37,7 +37,7 @@ class NGramWorker(NonLLMProposerWorkerBase, LoraNotSupportedWorkerBase):
         self.device = torch.device(f"cuda:{self.local_rank}")
         self.load_model = lambda *args, **kwargs: None

-        # Current only support Top1Proposer
+        # Current NGramWorker only supports Top1Proposer
         self._proposer = Top1Proposer(
             weakref.proxy(self),  # type: ignore[arg-type]
             device=self.device,
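A note on the weakref.proxy call in this hunk: the proposer holds a non-owning handle back to its worker, so worker -> proposer -> worker never forms a strong reference cycle. A small illustration of the pattern with simplified stand-in classes (not vLLM's):

import weakref

class Proposer:
    def __init__(self, worker):
        self.worker = worker  # may be a weak proxy

class Worker:
    def __init__(self):
        # Pass a proxy instead of self; the proxy behaves like the worker
        # for attribute access but does not keep it alive.
        self.proposer = Proposer(weakref.proxy(self))

w = Worker()
assert w.proposer.worker.proposer is w.proposer  # proxy forwards attributes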
@@ -24,7 +24,7 @@ class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):
     ) -> Tuple[Optional[List[SamplerOutput]], bool]:
         raise NotImplementedError

-    def set_include_gpu_probs_tensor(self):
+    def set_include_gpu_probs_tensor(self) -> None:
         """Implementation optional"""
         pass

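The `-> None` added here is more than cosmetic: mypy treats a def with no annotations at all as dynamically typed and skips checking its body and overrides. A hedged sketch of the difference, using simplified stand-in classes (not vLLM's):

class ProposerSketchBase:
    # With `-> None`, mypy considers this hook statically typed and checks
    # both its body and any subclass override; without an annotation the
    # whole def would be silently unchecked by default.
    def set_include_gpu_probs_tensor(self) -> None:
        """Implementation optional: subclasses may override this hook."""
        pass

class DraftWorkerSketch(ProposerSketchBase):
    def set_include_gpu_probs_tensor(self) -> None:
        self.include_gpu_probs_tensor = True  # hypothetical attribute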
@@ -206,7 +206,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):

        self.probs_dtype = self.spec_decode_sampler.probs_dtype
        self.token_id_dtype = self.spec_decode_sampler.token_id_dtype
-        # Lazy initiazliation.
+        # Lazy initialization.
        self.scorer: SpeculativeScorer

        # Hidden states from target model to pass to proposer
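The "lazy initialization" the fixed comment refers to is the bare annotation `self.scorer: SpeculativeScorer`: it declares the attribute's type for checkers without binding a value, leaving the real assignment for later (e.g. once the device is known). A minimal sketch with simplified stand-in classes (not vLLM's):

class SpeculativeScorer:
    def __init__(self, device: str) -> None:
        self.device = device

class SpecDecodeWorkerSketch:
    def __init__(self) -> None:
        # Declared for type checkers only; no value is bound yet, so reading
        # it before init_device() would raise AttributeError.
        self.scorer: SpeculativeScorer

    def init_device(self, device: str) -> None:
        self.scorer = SpeculativeScorer(device)

worker = SpecDecodeWorkerSketch()
worker.init_device("cuda:0")
assert worker.scorer.device == "cuda:0"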
@@ -138,7 +138,7 @@ class Top1Proposer(SpeculativeProposer):

        # Currently only proposal lens of 0 or the global batch proposal len
        # are supported.
-        # If max_proposal_len is defined, then we shall no exccess this
+        # If max_proposal_len is defined, then we shall not exceed this
        # quota for nonzero_proposal
        new_k = 0
        if (self.max_proposal_len is None
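The comments in this hunk describe an all-or-nothing rule: a sequence either receives the full global proposal length k or a proposal length of 0 when speculating would exceed max_proposal_len. A rough sketch of that rule (hypothetical helper, not Top1Proposer's code):

from typing import List, Optional

def per_sequence_proposal_lens(seq_lens: List[int],
                               k: int,
                               max_proposal_len: Optional[int]) -> List[int]:
    proposal_lens = []
    for seq_len in seq_lens:
        if max_proposal_len is None or seq_len + k <= max_proposal_len:
            proposal_lens.append(k)  # full speculation for this sequence
        else:
            proposal_lens.append(0)  # skip speculation; quota would be exceeded
    return proposal_lens

# Sequences near the cap are skipped rather than given a shorter proposal.
assert per_sequence_proposal_lens([10, 98], k=4, max_proposal_len=100) == [4, 0]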
@@ -219,7 +219,7 @@ class Top1Proposer(SpeculativeProposer):
        proposal_lens: List[int],
        nonzero_proposal_len_indices: List[int],
        sampler_transposed: bool,
-    ) -> Tuple[torch.Tensor, torch.tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """After speculations are produced, merge the speculation results with
        the skipped sequences.
        """
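To make the docstring in this last hunk concrete: proposals were run only for the sequences with nonzero proposal lens, so their outputs must be scattered back into full-batch tensors, with the skipped rows left as padding. A simplified sketch of that merge (hypothetical shapes and helper, not the real _merge_outputs):

from typing import List

import torch

def merge_outputs(batch_size: int,
                  k: int,
                  proposal_tokens: torch.Tensor,  # [num_nonzero, k]
                  nonzero_proposal_len_indices: List[int]) -> torch.Tensor:
    # Rows for skipped sequences stay at the padding value -1.
    merged = torch.full((batch_size, k), -1, dtype=proposal_tokens.dtype)
    merged[nonzero_proposal_len_indices] = proposal_tokens
    return merged

tokens = torch.tensor([[11, 12], [21, 22]])
# Rows 0 and 2 speculated; row 1 was skipped and remains padded with -1.
print(merge_outputs(3, 2, tokens, [0, 2]))
# tensor([[11, 12],
#         [-1, -1],
#         [21, 22]])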