Fix hanging in the scheduler caused by long prompts (#1534)

陈序 2023-11-21 08:06:49 +08:00 committed by GitHub
parent f5a37c6c6c
commit 3d4ceb292c
2 changed files with 36 additions and 5 deletions

vllm/core/block_manager.py

@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
+import enum
 from typing import Dict, List, Optional, Set, Tuple
 
 from vllm.block import PhysicalTokenBlock
@@ -54,6 +55,20 @@ class BlockAllocator:
 BlockTable = List[PhysicalTokenBlock]
 
 
+class AllocStatus(enum.Enum):
+    """Result for BlockSpaceManager.can_allocate
+
+    1. Ok: the seq_group can be allocated now.
+    2. Later: the seq_group cannot be allocated now, but the allocator's
+       capacity is larger than what the seq_group requires.
+    3. Never: the seq_group can never be allocated; it is too large to fit
+       in GPU memory.
+    """
+    OK = enum.auto()
+    LATER = enum.auto()
+    NEVER = enum.auto()
+
+
 class BlockSpaceManager:
     """Manages the mapping between logical and physical token blocks."""
@@ -86,7 +101,7 @@ class BlockSpaceManager:
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
 
-    def can_allocate(self, seq_group: SequenceGroup) -> bool:
+    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs()[0]
@@ -95,9 +110,15 @@ class BlockSpaceManager:
             num_required_blocks = min(num_required_blocks,
                                       self.block_sliding_window)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
+
         # Use watermark to avoid frequent cache eviction.
-        return (num_free_gpu_blocks - num_required_blocks >=
-                self.watermark_blocks)
+        if (self.num_total_gpu_blocks - num_required_blocks <
+                self.watermark_blocks):
+            return AllocStatus.NEVER
+        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
+            return AllocStatus.OK
+        else:
+            return AllocStatus.LATER
 
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same
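To make the new contract concrete, here is a minimal, self-contained sketch of how the three-way watermark check above behaves. This is illustrative only: the free-standing helper function and the block counts are not part of the commit, which implements the same logic as a method on BlockSpaceManager.

import enum

class AllocStatus(enum.Enum):
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()

def can_allocate(num_required_blocks: int,
                 num_free_gpu_blocks: int,
                 num_total_gpu_blocks: int,
                 watermark_blocks: int) -> AllocStatus:
    # NEVER: even a completely empty cache could not hold the prompt while
    # keeping the watermark of free blocks, so waiting will not help.
    if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
        return AllocStatus.NEVER
    # OK: enough blocks are free right now. Otherwise LATER: the prompt fits
    # in principle and should be retried once blocks are freed.
    if num_free_gpu_blocks - num_required_blocks >= watermark_blocks:
        return AllocStatus.OK
    return AllocStatus.LATER

# With 100 total GPU blocks and a watermark of 1 block:
print(can_allocate(200, 80, 100, 1))  # NEVER: prompt is larger than the whole cache
print(can_allocate(50, 10, 100, 1))   # LATER: fits, but not enough blocks are free yet
print(can_allocate(10, 80, 100, 1))   # OK: can be allocated now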

vllm/core/scheduler.py

@@ -3,7 +3,7 @@ import time
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.config import CacheConfig, SchedulerConfig
-from vllm.core.block_manager import BlockSpaceManager
+from vllm.core.block_manager import AllocStatus, BlockSpaceManager
 from vllm.core.policy import PolicyFactory
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
@@ -154,8 +154,18 @@ class Scheduler:
                 continue
 
             # If the sequence group cannot be allocated, stop.
-            if not self.block_manager.can_allocate(seq_group):
+            can_allocate = self.block_manager.can_allocate(seq_group)
+            if can_allocate == AllocStatus.LATER:
                 break
+            elif can_allocate == AllocStatus.NEVER:
+                logger.warning(
+                    f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                    f" and exceeds the capacity of block_manager")
+                for seq in seq_group.get_seqs():
+                    seq.status = SequenceStatus.FINISHED_IGNORED
+                ignored_seq_groups.append(seq_group)
+                self.waiting.pop(0)
+                continue
 
             # If the number of batched tokens exceeds the limit, stop.
             new_seq_lens = seq_lens + [num_prompt_tokens]
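For reference, a minimal sketch of why the three-way status fixes the hang. This is not the actual Scheduler code: schedule_prompts, waiting, block_manager, and ignored are stand-ins. With the old boolean return value, an oversized prompt at the head of the waiting queue made the loop break on every scheduling step and never make progress; with AllocStatus.NEVER the request is marked ignored and removed from the queue instead.

from collections import deque

from vllm.core.block_manager import AllocStatus

def schedule_prompts(waiting: deque, block_manager, ignored: list) -> list:
    scheduled = []
    while waiting:
        seq_group = waiting[0]
        status = block_manager.can_allocate(seq_group)
        if status is AllocStatus.LATER:
            # Not enough free blocks at the moment; retry on the next
            # scheduling step, when running requests may have freed blocks.
            break
        if status is AllocStatus.NEVER:
            # Previously this case returned False and hit the break above on
            # every step, stalling the queue; now the oversized request is
            # dropped from the queue and reported back as ignored.
            ignored.append(waiting.popleft())
            continue
        block_manager.allocate(seq_group)
        scheduled.append(waiting.popleft())
    return scheduled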