From a1b9cb2a3469e4e682e741af9a0f91e16923205d Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 20 Dec 2023 21:52:37 -0800 Subject: [PATCH] [BugFix] Fix recovery logic for sequence group (#2186) --- vllm/core/block_manager.py | 6 +++--- vllm/core/scheduler.py | 12 +++++++----- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 8b26319b..3bde0059 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -103,7 +103,7 @@ class BlockSpaceManager: def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: # FIXME(woosuk): Here we assume that all sequences in the group share # the same prompt. This may not be true for preempted sequences. - seq = seq_group.get_seqs()[0] + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -122,7 +122,7 @@ class BlockSpaceManager: def allocate(self, seq_group: SequenceGroup) -> None: # NOTE: Here we assume that all sequences in the group have the same # prompt. - seq = seq_group.get_seqs()[0] + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] # Allocate new physical token blocks that will store the prompt tokens. block_table: BlockTable = [] @@ -137,7 +137,7 @@ class BlockSpaceManager: block_table.append(block) # Assign the block table for each sequence. - for seq in seq_group.get_seqs(): + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() def can_append_slot(self, seq_group: SequenceGroup) -> bool: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc..398585a8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -139,15 +139,17 @@ class Scheduler: while self.waiting: seq_group = self.waiting[0] - assert seq_group.num_seqs() == 1, ( + waiting_seqs = seq_group.get_seqs( + status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_prompt_tokens = seq_group.get_seqs()[0].get_len() + num_prompt_tokens = waiting_seqs[0].get_len() if num_prompt_tokens > self.prompt_limit: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" f" and exceeds limit of {self.prompt_limit}") - for seq in seq_group.get_seqs(): + for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) self.waiting.pop(0) @@ -161,7 +163,7 @@ class Scheduler: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" f" and exceeds the capacity of block_manager") - for seq in seq_group.get_seqs(): + for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) self.waiting.pop(0) @@ -317,7 +319,7 @@ class Scheduler: def _allocate(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) - for seq in seq_group.get_seqs(): + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): seq.status = SequenceStatus.RUNNING def _append_slot(