diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index 50cf6366..37e73b21 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -5,6 +5,8 @@
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
 
+_MAX_NUM_BATCHED_TOKENS = 2048
+
 
 class Scheduler:
@@ -27,8 +29,8 @@ class Scheduler:
             num_cpu_blocks=num_cpu_blocks,
         )
 
-        # Serving sequence groups (FIFO).
-        self.serving: List[SequenceGroup] = []
+        # Running sequence groups (FIFO).
+        self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
         # Mapping: group_id -> max_num_steps.
@@ -54,7 +56,7 @@ class Scheduler:
         self.block_manager.allocate(seq_group)
         for seq in seq_group.seqs:
             seq.status = SequenceStatus.RUNNING
-        self.serving.append(seq_group)
+        self.running.append(seq_group)
         # FIXME
         self.num_steps[seq_group.group_id] = 0
 
@@ -73,7 +75,7 @@ class Scheduler:
         for seq in seq_group.seqs:
             if seq.status == SequenceStatus.SWAPPED:
                 seq.status = SequenceStatus.RUNNING
-        self.serving.append(seq_group)
+        self.running.append(seq_group)
 
     def _swap_out(self, seq_group: SequenceGroup) -> None:
         assert self.block_manager.can_swap_out(seq_group)
@@ -89,14 +91,14 @@ class Scheduler:
         # NOTE: Here we implicitly assume FCFS scheduling.
         # That is, the most recently added sequence group is the first
         # to be swapped out.
-        victim_idx = len(self.serving) - 1
-        for i, seq_group in enumerate(self.serving):
+        victim_idx = len(self.running) - 1
+        for i, seq_group in enumerate(self.running):
             if i > victim_idx:
                 # The i-th sequence group has already been swapped out.
                 break
             # OOM. Swap out the victim sequence groups.
             while not self.block_manager.can_append(seq_group):
-                victim_seq_group = self.serving[victim_idx]
+                victim_seq_group = self.running[victim_idx]
                 self._swap_out(victim_seq_group)
                 victim_idx -= 1
                 if i > victim_idx:
@@ -104,7 +106,7 @@
                     break
             else:
                 self._append(seq_group)
-        self.serving = self.serving[:victim_idx + 1]
+        self.running = self.running[:victim_idx + 1]
 
         # 2. Swap in the swapped sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
@@ -121,18 +123,27 @@
             # All swapped sequences are swapped in.
             self.swapped.clear()
 
+        num_batched_tokens = sum(
+            seq_group.num_seqs(status=SequenceStatus.RUNNING)
+            for seq_group in self.running
+        )
+
         # 3. Join new sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
-        # TODO(woosuk): Add a heuristic to control the maximum batch size.
+        # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
-            # FIXME: Acquire a lock.
+            # FIXME(woosuk): Acquire a lock to protect pending.
             for i, seq_group in enumerate(self.pending):
+                num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
-                    self._allocate(seq_group)
-                else:
-                    # FIXME: Consider the race condition.
-                    self.pending = self.pending[i:]
-                    break
+                    if (num_batched_tokens + num_prompt_tokens
+                        <= _MAX_NUM_BATCHED_TOKENS):
+                        self._allocate(seq_group)
+                        num_batched_tokens += num_prompt_tokens
+                        continue
+
+                self.pending = self.pending[i:]
+                break
             else:
                 self.pending.clear()
 
@@ -141,8 +152,35 @@
         if self.blocks_to_swap_in:
             assert not self.blocks_to_swap_out
 
+        # Create input data structures.
+        prompt_tokens: Dict[int, List[int]] = {}
+        generation_tokens: Dict[int, int] = {}
+        context_lens: Dict[int, int] = {}
+        block_tables: Dict[int, List[int]] = {}
+        for seq_group in self.running:
+            group_id = seq_group.group_id
+            num_steps = self.num_steps[group_id]
+            # NOTE(woosuk): We assume that the number of steps is 0
+            # for the prompt sequences.
+            is_prompt = num_steps == 0
+            for seq in seq_group.seqs:
+                if seq.status != SequenceStatus.RUNNING:
+                    continue
+
+                seq_id = seq.seq_id
+                block_tables[seq_id] = self.block_manager.get_block_table(seq)
+                if is_prompt:
+                    prompt_tokens[seq_id] = seq.get_token_ids()
+                else:
+                    generation_tokens[seq_id] = seq.get_token_ids()[-1]
+                context_lens[seq_id] = seq.get_len()
+
         # Execute the first stage of the pipeline.
         self.controllers[0].execute_stage(
+            prompt_tokens,
+            generation_tokens,
+            context_lens,
+            block_tables,
             self.blocks_to_swap_in.copy(),
             self.blocks_to_swap_out.copy(),
             self.blocks_to_copy.copy(),
@@ -158,7 +196,7 @@
         next_tokens: Dict[int, Tuple[int, int]],
     ) -> None:
         # Update the running sequences and free blocks.
-        for seq_group in self.serving:
+        for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
             stop_token_ids = self.stop_token_ids[group_id]
@@ -190,14 +228,14 @@
                 self._free_seq(seq)
                 continue
 
-        # Update the serving states.
-        serving: List[SequenceGroup] = []
-        for seq_group in self.serving:
+        # Update the running sequences.
+        running: List[SequenceGroup] = []
+        for seq_group in self.running:
             if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
                 del self.num_steps[seq_group.group_id]
                 del self.max_num_steps[seq_group.group_id]
                 del self.stop_token_ids[seq_group.group_id]
                 # TODO: Return the seq_group to the client.
             else:
-                serving.append(seq_group)
-        self.serving = serving
+                running.append(seq_group)
+        self.running = running