diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 55ce5aa6..b5e0da48 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -175,7 +175,7 @@ class Scheduler:
                 num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)
 
-            if scheduled:
+            if scheduled or ignored_seq_groups:
                 scheduler_outputs = SchedulerOutputs(
                     scheduled_seq_groups=scheduled,
                     prompt_run=True,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 859b6e15..74345430 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -294,14 +294,12 @@ class LLMEngine:
     def _schedule(
         self
     ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               Optional[List[RequestOutput]]]:
+               List[RequestOutput]]:
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        if scheduler_outputs.is_empty():
-            return seq_group_metadata_list, scheduler_outputs, [
-                RequestOutput.from_seq_group(seq_group)
-                for seq_group in scheduler_outputs.ignored_seq_groups
-            ]
-        return seq_group_metadata_list, scheduler_outputs, None
+        return seq_group_metadata_list, scheduler_outputs, [
+            RequestOutput.from_seq_group(seq_group)
+            for seq_group in scheduler_outputs.ignored_seq_groups
+        ]
 
     def _check_beam_search_early_stopping(
         self,
@@ -545,10 +543,9 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        (seq_group_metadata_list, scheduler_outputs,
-         early_return) = self._schedule()
-        if early_return is not None:
-            return early_return
+        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
+        if scheduler_outputs.is_empty():
+            return ignored
 
         # Execute the model.
         output = self._run_workers(
@@ -559,7 +556,7 @@ class LLMEngine:
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
 
-        return self._process_model_outputs(output, scheduler_outputs)
+        return self._process_model_outputs(output, scheduler_outputs) + ignored
 
     def _log_system_stats(
         self,
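Context for the change: the scheduler marks requests whose prompts exceed the model's length limit as ignored rather than scheduling them, and before this patch their `RequestOutput`s were only surfaced on the early-return path taken when nothing at all was scheduled. With the patch, `_schedule` always returns the ignored outputs and `step` appends them to the regular results. The sketch below is not part of the patch; it assumes vLLM's public engine API of this era (`LLMEngine.from_engine_args`, `add_request`, `step`), and the model name, request id, and prompt length are illustrative only.

```python
# Minimal sketch, assuming the standard LLMEngine entry points.
from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

# A prompt longer than the model's maximum length lands in
# ignored_seq_groups instead of being scheduled. After this patch, step()
# returns its (finished) RequestOutput even when other requests were
# scheduled in the same step, not only when the schedule came back empty.
engine.add_request("overlong", "hello " * 100_000,
                   SamplingParams(max_tokens=16))
outputs = engine.step()
for output in outputs:
    print(output.request_id, output.finished)
```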