Implement scheduler.step & Add a threshold for batch size

Woosuk Kwon 2023-02-23 07:54:20 +00:00
parent 501c4bd0cd
commit 331fa0b042


@@ -5,6 +5,8 @@ from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
 
+_MAX_NUM_BATCHED_TOKENS = 2048
+
 class Scheduler:
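
The new module-level constant caps how many tokens a single scheduler step may process. Under that cap, a running generation sequence counts as one token per step, while a newly admitted prompt counts its full length (see step 3 below). A minimal sketch of that accounting, not cacheflow code, with hypothetical (is_prompt, length) pairs standing in for SequenceGroup:

    # Sketch: per-step token accounting under the new threshold.
    # (is_prompt, length) pairs are hypothetical stand-ins for
    # cacheflow's SequenceGroup objects.
    _MAX_NUM_BATCHED_TOKENS = 2048

    def tokens_this_step(sequences):
        # A prompt is processed in full in one step; a generation
        # sequence produces exactly one new token per step.
        return sum(length if is_prompt else 1 for is_prompt, length in sequences)

    batch = [(True, 512), (False, 300), (False, 90)]
    assert tokens_this_step(batch) == 514  # 512 prompt tokens + 2 decode tokens
    assert tokens_this_step(batch) <= _MAX_NUM_BATCHED_TOKENS
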
@@ -27,8 +29,8 @@ class Scheduler:
             num_cpu_blocks=num_cpu_blocks,
         )
 
-        # Serving sequence groups (FIFO).
-        self.serving: List[SequenceGroup] = []
+        # Running sequence groups (FIFO).
+        self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
         # Mapping: group_id -> max_num_steps.
@@ -54,7 +56,7 @@ class Scheduler:
         self.block_manager.allocate(seq_group)
         for seq in seq_group.seqs:
             seq.status = SequenceStatus.RUNNING
-        self.serving.append(seq_group)
+        self.running.append(seq_group)
 
         # FIXME
         self.num_steps[seq_group.group_id] = 0
@@ -73,7 +75,7 @@ class Scheduler:
         for seq in seq_group.seqs:
             if seq.status == SequenceStatus.SWAPPED:
                 seq.status = SequenceStatus.RUNNING
-        self.serving.append(seq_group)
+        self.running.append(seq_group)
 
     def _swap_out(self, seq_group: SequenceGroup) -> None:
         assert self.block_manager.can_swap_out(seq_group)
@@ -89,14 +91,14 @@ class Scheduler:
         # NOTE: Here we implicitly assume FCFS scheduling.
         # That is, the most recently added sequence group is the first
         # to be swapped out.
-        victim_idx = len(self.serving) - 1
-        for i, seq_group in enumerate(self.serving):
+        victim_idx = len(self.running) - 1
+        for i, seq_group in enumerate(self.running):
             if i > victim_idx:
                 # The i-th sequence group has already been swapped out.
                 break
             # OOM. Swap out the victim sequence groups.
             while not self.block_manager.can_append(seq_group):
-                victim_seq_group = self.serving[victim_idx]
+                victim_seq_group = self.running[victim_idx]
                 self._swap_out(victim_seq_group)
                 victim_idx -= 1
                 if i > victim_idx:
@@ -104,7 +106,7 @@ class Scheduler:
                     break
             else:
                 self._append(seq_group)
-        self.serving = self.serving[:victim_idx + 1]
+        self.running = self.running[:victim_idx + 1]
 
         # 2. Swap in the swapped sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
@@ -121,16 +123,25 @@ class Scheduler:
             # All swapped sequences are swapped in.
             self.swapped.clear()
 
+        num_batched_tokens = sum(
+            seq_group.num_seqs(status=SequenceStatus.RUNNING)
+            for seq_group in self.running
+        )
+
         # 3. Join new sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
-        # TODO(woosuk): Add a heuristic to control the maximum batch size.
+        # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
-            # FIXME: Acquire a lock.
+            # FIXME(woosuk): Acquire a lock to protect pending.
             for i, seq_group in enumerate(self.pending):
+                num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
-                    self._allocate(seq_group)
-                else:
-                    # FIXME: Consider the race condition.
-                    self.pending = self.pending[i:]
-                    break
+                    if (num_batched_tokens + num_prompt_tokens
+                        <= _MAX_NUM_BATCHED_TOKENS):
+                        self._allocate(seq_group)
+                        num_batched_tokens += num_prompt_tokens
+                        continue
+
+                self.pending = self.pending[i:]
+                break
             else:
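
With the threshold in place, step 3 admits pending groups in FCFS order only while the running batch stays under _MAX_NUM_BATCHED_TOKENS; the first prompt that would overflow the budget stops admission entirely rather than letting shorter prompts behind it jump the queue. A sketch of the same loop, assuming block allocation always succeeds and reducing each pending group to its prompt length:

    # Sketch of the admission loop above (step 3).
    _MAX_NUM_BATCHED_TOKENS = 2048

    def admit_pending(pending, num_batched_tokens):
        admitted = []
        for i, num_prompt_tokens in enumerate(pending):
            if num_batched_tokens + num_prompt_tokens <= _MAX_NUM_BATCHED_TOKENS:
                admitted.append(num_prompt_tokens)
                num_batched_tokens += num_prompt_tokens
                continue
            # FCFS: a prompt that does not fit blocks everything behind it.
            return admitted, pending[i:]
        return admitted, []

    admitted, still_pending = admit_pending([1000, 900, 400], num_batched_tokens=100)
    assert admitted == [1000, 900]  # 100 + 1000 + 900 = 2000 <= 2048
    assert still_pending == [400]   # 2000 + 400 = 2400 > 2048
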
@@ -141,8 +152,35 @@ class Scheduler:
         if self.blocks_to_swap_in:
             assert not self.blocks_to_swap_out
 
+        # Create input data structures.
+        prompt_tokens: Dict[int, List[int]] = {}
+        generation_tokens: Dict[int, int] = {}
+        context_lens: Dict[int, int] = {}
+        block_tables: Dict[int, List[int]] = {}
+        for seq_group in self.running:
+            group_id = seq_group.group_id
+            num_steps = self.num_steps[group_id]
+            # NOTE(woosuk): We assume that the number of steps is 0
+            # for the prompt sequences.
+            is_prompt = num_steps == 0
+            for seq in seq_group.seqs:
+                if seq.status != SequenceStatus.RUNNING:
+                    continue
+
+                seq_id = seq.seq_id
+                block_tables[seq_id] = self.block_manager.get_block_table(seq)
+                if is_prompt:
+                    prompt_tokens[seq_id] = seq.get_token_ids()
+                else:
+                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
+                    context_lens[seq_id] = seq.get_len()
+
         # Execute the first stage of the pipeline.
         self.controllers[0].execute_stage(
+            prompt_tokens,
+            generation_tokens,
+            context_lens,
+            block_tables,
             self.blocks_to_swap_in.copy(),
             self.blocks_to_swap_out.copy(),
             self.blocks_to_copy.copy(),
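
step now builds the model inputs itself, as four dictionaries keyed by seq_id: a prompt-phase sequence (num_steps == 0) ships its full token list, a generation-phase sequence ships only its last token plus a context length, and every running sequence gets its block table. Roughly, the shapes look like this, with made-up IDs and values:

    # Hypothetical contents after the loop above, for one prompt-phase
    # sequence (id 0) and one generation-phase sequence (id 1).
    prompt_tokens = {0: [5, 91, 34, 7]}       # full prompt, one step
    generation_tokens = {1: 42}               # only the latest token
    context_lens = {1: 17}                    # decode needs the context length
    block_tables = {0: [3, 8], 1: [1, 4, 9]}  # every RUNNING seq has one

    # A sequence is in exactly one phase, and block_tables covers both.
    assert set(prompt_tokens).isdisjoint(generation_tokens)
    assert set(block_tables) == set(prompt_tokens) | set(generation_tokens)
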
@@ -158,7 +196,7 @@ class Scheduler:
         next_tokens: Dict[int, Tuple[int, int]],
     ) -> None:
         # Update the running sequences and free blocks.
-        for seq_group in self.serving:
+        for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
             stop_token_ids = self.stop_token_ids[group_id]
@@ -190,14 +228,14 @@ class Scheduler:
                     self._free_seq(seq)
                     continue
 
-        # Update the serving states.
-        serving: List[SequenceGroup] = []
-        for seq_group in self.serving:
+        # Update the running sequences.
+        running: List[SequenceGroup] = []
+        for seq_group in self.running:
             if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
                 del self.num_steps[seq_group.group_id]
                 del self.max_num_steps[seq_group.group_id]
                 del self.stop_token_ids[seq_group.group_id]
                 # TODO: Return the seq_group to the client.
             else:
-                serving.append(seq_group)
-        self.serving = serving
+                running.append(seq_group)
+        self.running = running
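
After the per-sequence updates, the method (its name falls outside this hunk) rebuilds running without the groups whose sequences have all FINISHED, deleting the three per-group tables in lockstep so none of them leaks entries for completed groups. A sketch of that filter on toy dictionaries, using bare group IDs where the real list holds SequenceGroup objects:

    # Sketch of the cleanup above: the per-group dicts are keyed by
    # group_id and must be pruned together with the running list.
    def prune_finished(running, num_steps, max_num_steps, stop_token_ids, is_finished):
        still_running = []
        for group_id in running:
            if is_finished(group_id):
                del num_steps[group_id]
                del max_num_steps[group_id]
                del stop_token_ids[group_id]
            else:
                still_running.append(group_id)
        return still_running

    steps, limits, stops = {7: 3, 8: 1}, {7: 16, 8: 16}, {7: [2], 8: [2]}
    running = prune_finished([7, 8], steps, limits, stops, lambda g: g == 7)
    assert running == [8] and 7 not in steps and 7 not in limits and 7 not in stops
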