diff --git a/cacheflow/master/frontend.py b/cacheflow/master/frontend.py
new file mode 100644
index 00000000..a34f754b
--- /dev/null
+++ b/cacheflow/master/frontend.py
@@ -0,0 +1,56 @@
+from typing import List, Optional, Tuple
+
+from transformers import AutoTokenizer
+
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.sequence import Sequence
+from cacheflow.sequence import SequenceGroup
+from cacheflow.utils import Counter
+
+
+class Frontend:
+
+    def __init__(
+        self,
+        model_name: str,
+        block_size: int,
+    ) -> None:
+        self.block_size = block_size
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.seq_group_counter = Counter()
+        self.seq_counter = Counter()
+        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []
+
+    def query(
+        self,
+        prompt: str,
+        sampling_params: Optional[SamplingParams] = None,
+    ) -> None:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        token_ids: List[int] = self.tokenizer.encode(prompt)
+
+        seqs: List[Sequence] = []
+        for _ in range(sampling_params.n):
+            seq_id = next(self.seq_counter)
+            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
+            seqs.append(seq)
+
+        group_id = next(self.seq_group_counter)
+        seq_group = SequenceGroup(group_id, seqs)
+        self.inputs.append((seq_group, sampling_params))
+
+    def get_inputs(self) -> List[Tuple[SequenceGroup, SamplingParams]]:
+        inputs = self.inputs
+        self.inputs = []
+        return inputs
+
+    def print_response(
+        self,
+        seq_group: SequenceGroup,
+    ) -> None:
+        for seq in seq_group.seqs:
+            token_ids = seq.get_token_ids()
+            output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+            print(f'Seq {seq.seq_id}: {output}')
diff --git a/cacheflow/master/scheduler.py b/cacheflow/master/scheduler.py
index c5df8d49..7b125002 100644
--- a/cacheflow/master/scheduler.py
+++ b/cacheflow/master/scheduler.py
@@ -1,6 +1,8 @@
 from typing import Dict, List, Tuple
 
 from cacheflow.master.block_manager import BlockSpaceManager
+from cacheflow.master.frontend import Frontend
+from cacheflow.sampling_params import SamplingParams
 from cacheflow.sequence import Sequence
 from cacheflow.sequence import SequenceGroup
 from cacheflow.sequence import SequenceStatus
@@ -12,11 +14,13 @@ class Scheduler:
 
     def __init__(
        self,
+        frontend: Frontend,
         controllers: List,
         block_size: int,
         num_gpu_blocks: int,
         num_cpu_blocks: int,
     ) -> None:
+        self.frontend = frontend
         self.controllers = controllers
         self.block_size = block_size
         self.num_gpu_blocks = num_gpu_blocks
@@ -33,16 +37,20 @@ class Scheduler:
         self.running: List[SequenceGroup] = []
         # Mapping: group_id -> num_steps.
         self.num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> max_num_steps.
-        self.max_num_steps: Dict[int, int] = {}
-        # Mapping: group_id -> stop_token_ids.
-        self.stop_token_ids: Dict[int, List[int]] = {}
+        # Mapping: group_id -> sampling params.
+        self.sampling_params: Dict[int, SamplingParams] = {}
 
         # Swapped sequence groups (LIFO).
         self.swapped: List[SequenceGroup] = []
         # Pending sequence groups (FIFO).
         self.pending: List[SequenceGroup] = []
 
+    def _fetch_inputs(self) -> None:
+        inputs = self.frontend.get_inputs()
+        for seq_group, sampling_params in inputs:
+            self.pending.append(seq_group)
+            self.sampling_params[seq_group.group_id] = sampling_params
+
     def _free_seq(self, seq: Sequence) -> None:
         seq.status = SequenceStatus.FINISHED
         self.block_manager.free(seq)
@@ -145,6 +153,7 @@ class Scheduler:
         # TODO(woosuk): Add a batching policy to control the batch size.
         if not self.swapped:
             # FIXME(woosuk): Acquire a lock to protect pending.
+            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
@@ -205,7 +214,7 @@ class Scheduler:
         for seq_group in self.running:
             group_id = seq_group.group_id
             self.num_steps[group_id] += 1
-            stop_token_ids = self.stop_token_ids[group_id]
+            stop_token_ids = self.sampling_params[group_id].stop_token_ids
 
             for seq in seq_group.seqs:
                 if seq.status == SequenceStatus.FINISHED:
@@ -230,24 +239,22 @@ class Scheduler:
                     continue
 
                 # Check if the sequence has reached the maximum number of steps.
-                if self.num_steps[group_id] == self.max_num_steps[group_id]:
+                max_num_steps = self.sampling_params[group_id].max_num_steps
+                if self.num_steps[group_id] == max_num_steps:
                     self._free_seq(seq)
                     continue
 
         # Update the running sequences.
         running: List[SequenceGroup] = []
         for seq_group in self.running:
-            if all(seq.status == SequenceStatus.FINISHED for seq in seq_group.seqs):
-                del self.num_steps[seq_group.group_id]
-                del self.max_num_steps[seq_group.group_id]
-                del self.stop_token_ids[seq_group.group_id]
-                # TODO: Return the seq_group to the client.
-                from transformers import AutoTokenizer
-                tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-                for seq in seq_group.seqs:
-                    token_ids = seq.get_token_ids()
-                    output = tokenizer.decode(token_ids, skip_special_tokens=True)
-                    print(f'Seq {seq.seq_id}: {output}')
+            if seq_group.is_finished():
+                self._return(seq_group)
             else:
                 running.append(seq_group)
         self.running = running
+
+    def _return(self, seq_group: SequenceGroup) -> None:
+        group_id = seq_group.group_id
+        del self.num_steps[group_id]
+        del self.sampling_params[group_id]
+        self.frontend.print_response(seq_group)
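
For reference, a minimal usage sketch of the new interface follows. It is not part of the patch: the controller list, block counts, and the way the master loop drives the scheduler are illustrative assumptions; only the Frontend/Scheduler calls shown in the diff above come from the patch itself.

# Usage sketch (not part of the patch); values below are placeholders.
from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler

# The frontend owns the tokenizer and turns a prompt into a SequenceGroup.
frontend = Frontend(model_name='facebook/opt-125m', block_size=8)
frontend.query('Hello, my name is')  # uses default SamplingParams

# The scheduler now takes the frontend and drains it via _fetch_inputs()
# when scheduling pending groups; controllers and block counts here are
# placeholders, not values prescribed by the diff.
scheduler = Scheduler(
    frontend=frontend,
    controllers=[],
    block_size=8,
    num_gpu_blocks=1024,
    num_cpu_blocks=256,
)

# Once a sequence group finishes, the scheduler calls _return(), which frees
# its bookkeeping entries and hands the group back to the frontend through
# frontend.print_response(seq_group). The loop that actually drives scheduling
# steps is outside the scope of this diff.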