Fix hanging in the scheduler caused by long prompts (#1534)
parent f5a37c6c6c
commit 3d4ceb292c
vllm/core/block_manager.py
@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
+import enum
 from typing import Dict, List, Optional, Set, Tuple
 
 from vllm.block import PhysicalTokenBlock
@@ -54,6 +55,20 @@ class BlockAllocator:
 BlockTable = List[PhysicalTokenBlock]
 
 
+class AllocStatus(enum.Enum):
+    """Result for BlockSpaceManager.can_allocate
+
+    1. Ok: seq_group can be allocated now.
+    2. Later: seq_group cannot be allocated now.
+       The capacity of the allocator is larger than what seq_group requires.
+    3. Never: seq_group can never be allocated.
+       The seq_group is too large to be allocated on the GPU.
+    """
+    OK = enum.auto()
+    LATER = enum.auto()
+    NEVER = enum.auto()
+
+
 class BlockSpaceManager:
     """Manages the mapping between logical and physical token blocks."""
 
@@ -86,7 +101,7 @@ class BlockSpaceManager:
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
 
-    def can_allocate(self, seq_group: SequenceGroup) -> bool:
+    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs()[0]
@@ -95,9 +110,15 @@ class BlockSpaceManager:
             num_required_blocks = min(num_required_blocks,
                                       self.block_sliding_window)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
 
         # Use watermark to avoid frequent cache eviction.
-        return (num_free_gpu_blocks - num_required_blocks >=
-                self.watermark_blocks)
+        if (self.num_total_gpu_blocks - num_required_blocks <
+                self.watermark_blocks):
+            return AllocStatus.NEVER
+        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
+            return AllocStatus.OK
+        else:
+            return AllocStatus.LATER
 
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same
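The core of the fix is the three-way decision above: a prompt that needs more blocks than the GPU can ever provide now maps to NEVER, instead of looking identical to a cache that merely happens to be full at the moment. Below is a minimal standalone sketch of that decision, using plain integers in place of the block manager's internal state; the parameter names are illustrative and simply mirror the fields used in the diff.

    # Sketch only, not the vLLM code itself: the allocation verdict as a pure function.
    import enum


    class AllocStatus(enum.Enum):
        OK = enum.auto()      # allocate now
        LATER = enum.auto()   # retry once blocks are freed
        NEVER = enum.auto()   # the prompt can never fit; drop it


    def can_allocate(num_required_blocks: int,
                     num_free_gpu_blocks: int,
                     num_total_gpu_blocks: int,
                     watermark_blocks: int) -> AllocStatus:
        # Even with every GPU block free, this prompt would leave less than the
        # watermark: it can never be scheduled, so report NEVER instead of
        # letting the scheduler wait for blocks that will never appear.
        if num_total_gpu_blocks - num_required_blocks < watermark_blocks:
            return AllocStatus.NEVER
        # Enough blocks are free right now, keeping the watermark as headroom.
        if num_free_gpu_blocks - num_required_blocks >= watermark_blocks:
            return AllocStatus.OK
        # Not enough free blocks yet, but the request fits in principle.
        return AllocStatus.LATER


    assert can_allocate(10, 50, 100, 1) is AllocStatus.OK
    assert can_allocate(60, 50, 100, 1) is AllocStatus.LATER
    assert can_allocate(200, 100, 100, 1) is AllocStatus.NEVER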
vllm/core/scheduler.py
@@ -3,7 +3,7 @@ import time
 from typing import Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.config import CacheConfig, SchedulerConfig
-from vllm.core.block_manager import BlockSpaceManager
+from vllm.core.block_manager import AllocStatus, BlockSpaceManager
 from vllm.core.policy import PolicyFactory
 from vllm.logger import init_logger
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
@@ -154,8 +154,18 @@ class Scheduler:
                     continue
 
                 # If the sequence group cannot be allocated, stop.
-                if not self.block_manager.can_allocate(seq_group):
+                can_allocate = self.block_manager.can_allocate(seq_group)
+                if can_allocate == AllocStatus.LATER:
                     break
+                elif can_allocate == AllocStatus.NEVER:
+                    logger.warning(
+                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
+                        f" and exceeds the capacity of block_manager")
+                    for seq in seq_group.get_seqs():
+                        seq.status = SequenceStatus.FINISHED_IGNORED
+                    ignored_seq_groups.append(seq_group)
+                    self.waiting.pop(0)
+                    continue
 
                 # If the number of batched tokens exceeds the limit, stop.
                 new_seq_lens = seq_lens + [num_prompt_tokens]
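For context, a simplified sketch of how a waiting-queue loop consumes the three statuses. This is illustrative only; the real Scheduler in the diff above operates on SequenceGroups and marks dropped sequences as FINISHED_IGNORED. The key point is that NEVER no longer falls into the same break path that previously retried the oversized prompt forever, which is what caused the hang.

    # Sketch only: draining a waiting queue with the three-way status.
    # Relies on the AllocStatus enum sketched earlier; can_allocate and
    # allocate are passed in as callables for simplicity.
    from collections import deque


    def schedule_step(waiting: deque, can_allocate, allocate) -> list:
        scheduled = []
        while waiting:
            request = waiting[0]
            status = can_allocate(request)
            if status is AllocStatus.LATER:
                # Not enough free blocks yet; stop and retry on the next step.
                break
            if status is AllocStatus.NEVER:
                # Previously this case also stopped the loop, so the same
                # request was retried forever. Now it is dropped instead.
                waiting.popleft()
                continue
            # AllocStatus.OK: take the request and allocate its blocks.
            waiting.popleft()
            allocate(request)
            scheduled.append(request)
        return scheduled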