[BugFix] Fix recovery logic for sequence group (#2186)

This commit is contained in:
Woosuk Kwon 2023-12-20 21:52:37 -08:00 committed by GitHub
parent 3a4fd5ca59
commit a1b9cb2a34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 8 deletions

View File

@@ -103,7 +103,7 @@ class BlockSpaceManager:
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share # FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences. # the same prompt. This may not be true for preempted sequences.
seq = seq_group.get_seqs()[0] seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks) num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None: if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks, num_required_blocks = min(num_required_blocks,
@@ -122,7 +122,7 @@ class BlockSpaceManager:
def allocate(self, seq_group: SequenceGroup) -> None: def allocate(self, seq_group: SequenceGroup) -> None:
# NOTE: Here we assume that all sequences in the group have the same # NOTE: Here we assume that all sequences in the group have the same
# prompt. # prompt.
seq = seq_group.get_seqs()[0] seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
# Allocate new physical token blocks that will store the prompt tokens. # Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = [] block_table: BlockTable = []
@@ -137,7 +137,7 @@ class BlockSpaceManager:
block_table.append(block) block_table.append(block)
# Assign the block table for each sequence. # Assign the block table for each sequence.
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy() self.block_tables[seq.seq_id] = block_table.copy()
def can_append_slot(self, seq_group: SequenceGroup) -> bool: def can_append_slot(self, seq_group: SequenceGroup) -> bool:

View File

@@ -139,15 +139,17 @@ class Scheduler:
while self.waiting: while self.waiting:
seq_group = self.waiting[0] seq_group = self.waiting[0]
assert seq_group.num_seqs() == 1, ( waiting_seqs = seq_group.get_seqs(
status=SequenceStatus.WAITING)
assert len(waiting_seqs) == 1, (
"Waiting sequence group should have only one prompt " "Waiting sequence group should have only one prompt "
"sequence.") "sequence.")
num_prompt_tokens = seq_group.get_seqs()[0].get_len() num_prompt_tokens = waiting_seqs[0].get_len()
if num_prompt_tokens > self.prompt_limit: if num_prompt_tokens > self.prompt_limit:
logger.warning( logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long" f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds limit of {self.prompt_limit}") f" and exceeds limit of {self.prompt_limit}")
for seq in seq_group.get_seqs(): for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
self.waiting.pop(0) self.waiting.pop(0)
@@ -161,7 +163,7 @@ class Scheduler:
logger.warning( logger.warning(
f"Input prompt ({num_prompt_tokens} tokens) is too long" f"Input prompt ({num_prompt_tokens} tokens) is too long"
f" and exceeds the capacity of block_manager") f" and exceeds the capacity of block_manager")
for seq in seq_group.get_seqs(): for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group) ignored_seq_groups.append(seq_group)
self.waiting.pop(0) self.waiting.pop(0)
@@ -317,7 +319,7 @@ class Scheduler:
def _allocate(self, seq_group: SequenceGroup) -> None: def _allocate(self, seq_group: SequenceGroup) -> None:
self.block_manager.allocate(seq_group) self.block_manager.allocate(seq_group)
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING seq.status = SequenceStatus.RUNNING
def _append_slot( def _append_slot(