diff --git a/cacheflow/core/block_manager.py b/cacheflow/core/block_manager.py
index 2d471e0b..8f1295bc 100644
--- a/cacheflow/core/block_manager.py
+++ b/cacheflow/core/block_manager.py
@@ -80,7 +80,7 @@ class BlockSpaceManager:
     def can_allocate(self, seq_group: SequenceGroup) -> bool:
         # FIXME(woosuk): Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]
         num_required_blocks = len(seq.logical_token_blocks)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
         # Use watermark to avoid frequent cache eviction.
@@ -88,7 +88,7 @@ class BlockSpaceManager:
     def allocate(self, seq_group: SequenceGroup) -> None:
         # NOTE: Here we assume that all sequences in the group have the same prompt.
-        seq = seq_group.seqs[0]
+        seq = seq_group.get_seqs()[0]

         # Allocate new physical token blocks that will store the prompt tokens.
         block_table: BlockTable = []
@@ -99,7 +99,7 @@ class BlockSpaceManager:
             block_table.append(block)

         # Assign the block table for each sequence.
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             self.block_tables[seq.seq_id] = block_table.copy()

     def can_append_slot(self, seq_group: SequenceGroup) -> bool:
@@ -147,7 +147,7 @@ class BlockSpaceManager:
         # NOTE: Here, we assume that the physical blocks are only shared by
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             block_table = self.block_tables[seq.seq_id]
@@ -168,7 +168,7 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
@@ -199,7 +199,7 @@ class BlockSpaceManager:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             if seq.status == SequenceStatus.FINISHED:
                 continue
             new_block_table: BlockTable = []
diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py
index ccbd9396..08f85573 100644
--- a/cacheflow/core/scheduler.py
+++ b/cacheflow/core/scheduler.py
@@ -73,8 +73,6 @@ class Scheduler:
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
         self.running: List[SequenceGroup] = []
-        # Mapping: request_id -> num_steps.
-        self.num_steps: Dict[str, int] = {}
         # Sequence groups in the SWAPPED state.
         self.swapped: List[SequenceGroup] = []

@@ -84,7 +82,6 @@ class Scheduler:
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
-        assert seq_group.request_id not in self.num_steps
         self.waiting.append(seq_group)

     def has_unfinished_seqs(self) -> bool:
@@ -178,7 +175,7 @@ class Scheduler:
                     break

                 # If the number of batched tokens exceeds the limit, stop.
-                num_prompt_tokens = seq_group.seqs[0].get_len()
+                num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if (num_batched_tokens + num_prompt_tokens
                     > self.scheduler_config.max_num_batched_tokens):
                     break
@@ -278,15 +275,8 @@ class Scheduler:
     ) -> List[SequenceGroup]:
         # Update the running sequences and free blocks.
         for seq_group in self.running:
-            request_id = seq_group.request_id
-            self.num_steps[request_id] += 1
-            stop_token_ids = seq_group.sampling_params.stop_token_ids
-
-            # Process beam search results before processing the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
-
+            # Process beam search results before processing the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 output = seq_outputs[seq.seq_id]
                 if seq.seq_id != output.parent_seq_id:
                     # The sequence is a fork of the parent sequence (beam search).
@@ -297,43 +287,27 @@ class Scheduler:
                     parent_seq.fork(seq)
                     self.block_manager.fork(parent_seq, seq)

-            # Process the next tokens.
-            for seq in seq_group.seqs:
-                if seq.status == SequenceStatus.FINISHED:
-                    continue
-
+            # Process the new tokens.
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token(output.output_token, output.logprobs)
+        return self.running.copy()

-                # Check if the sequence has generated a stop token.
-                if output.output_token in stop_token_ids:
-                    self._free_seq(seq)
-                    continue
+    def free_seq(self, seq: Sequence) -> None:
+        seq.status = SequenceStatus.FINISHED
+        self.block_manager.free(seq)

-                # Check if the sequence has reached the maximum number of steps.
-                max_num_steps = seq_group.sampling_params.max_tokens
-                if self.num_steps[request_id] == max_num_steps:
-                    self._free_seq(seq)
-                    continue
-
-        # Update the running sequences.
-        updated = self.running.copy()
-        running: List[SequenceGroup] = []
-        for seq_group in self.running:
-            if seq_group.is_finished():
-                self._free_seq_group(seq_group)
-            else:
-                running.append(seq_group)
-        self.running = running
-        return updated
+    def free_finished_seq_groups(self) -> None:
+        self.running = [
+            seq_group for seq_group in self.running
+            if not seq_group.is_finished()
+        ]

     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
-        for seq in seq_group.seqs:
+        for seq in seq_group.get_seqs():
             seq.status = SequenceStatus.RUNNING
-        if seq_group.request_id not in self.num_steps:
-            self.num_steps[seq_group.request_id] = 0

     def _append_slot(
         self,
@@ -403,13 +377,6 @@ class Scheduler:
             self._swap_out(seq_group, blocks_to_swap_out)
             self.swapped.append(seq_group)

-    def _free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
-        self.block_manager.free(seq)
-
-    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
-        del self.num_steps[seq_group.request_id]
-
     def _swap_in(
         self,
         seq_group: SequenceGroup,
diff --git a/cacheflow/entrypoints/fastapi_server.py b/cacheflow/entrypoints/fastapi_server.py
index ed8afda1..26882f10 100644
--- a/cacheflow/entrypoints/fastapi_server.py
+++ b/cacheflow/entrypoints/fastapi_server.py
@@ -123,6 +123,7 @@ if __name__ == "__main__":
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
-    server = FastAPIServer(
-        args.use_ray, *server_configs, distributed_init_method, stage_devices)
+    server = FastAPIServer(args.use_ray, *server_configs,
+                           distributed_init_method, stage_devices,
+                           log_stats=not args.disable_log_stats)

     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/cacheflow/model_executor/layers/sampler.py b/cacheflow/model_executor/layers/sampler.py
index 84cccadb..1c3187c0 100644
--- a/cacheflow/model_executor/layers/sampler.py
+++ b/cacheflow/model_executor/layers/sampler.py
@@ -283,20 +283,20 @@ def _sample_from_prompt(
 ) -> List[int]:
     if sampling_params.use_beam_search:
         # Beam search.
-        beam_width = sampling_params.n
+        beam_width = sampling_params.best_of
         _, next_token_ids = torch.topk(prob, beam_width)
         next_token_ids = next_token_ids.tolist()
     elif sampling_params.temperature == 0.0:
         # Greedy sampling.
-        assert sampling_params.n == 1
+        assert sampling_params.best_of == 1
         next_token_id = torch.argmax(prob)
         next_token_ids = [next_token_id.item()]
     else:
         # Random sampling.
-        # Sample n tokens for the prompt.
-        n = sampling_params.n
+        # Sample `best_of` tokens for the prompt.
+        num_seqs = sampling_params.best_of
         next_token_ids = torch.multinomial(
-            prob, num_samples=n, replacement=True)
+            prob, num_samples=num_seqs, replacement=True)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids

@@ -308,7 +308,7 @@ def _sample_from_generation_tokens(
     seq_logprobs: List[float],
     sampling_params: SamplingParams,
 ) -> Tuple[List[int], List[int]]:
-    # NOTE(woosuk): sampling_params.n can be greater than
+    # NOTE(woosuk): sampling_params.best_of can be greater than
     # len(seq_ids) because some sequences in the group might have
     # been already terminated.
     if sampling_params.use_beam_search:
@@ -372,7 +372,7 @@ def _sample(
         seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == sampling_params.n
+            assert len(seq_ids) == sampling_params.best_of
             prob = probs[idx]
             logprob = logprobs[idx]
             idx += 1
@@ -397,7 +397,7 @@ def _sample(
             # Sample the next tokens.
             seq_logprobs = [
-                input_metadata.seq_data[seq_id].cumulative_logprobs
+                input_metadata.seq_data[seq_id].cumulative_logprob
                 for seq_id in seq_ids]
             parent_seq_ids, next_token_ids = _sample_from_generation_tokens(
                 seq_ids, prob, logprob, seq_logprobs, sampling_params)
diff --git a/cacheflow/outputs.py b/cacheflow/outputs.py
index f6ba3678..0b4dcabe 100644
--- a/cacheflow/outputs.py
+++ b/cacheflow/outputs.py
@@ -1,6 +1,4 @@
-from typing import Dict, List, Union
-
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from typing import Dict, List

 from cacheflow.sequence import SequenceGroup

@@ -9,20 +7,23 @@ class CompletionOutput:
     def __init__(
         self,
+        index: int,
         text: str,
         token_ids: List[int],
-        cumulative_logprobs: float,
+        cumulative_logprob: float,
         logprobs: List[Dict[int, float]],
     ) -> None:
+        self.index = index
         self.text = text
         self.token_ids = token_ids
-        self.cumulative_logprobs = cumulative_logprobs
+        self.cumulative_logprob = cumulative_logprob
         self.logprobs = logprobs

     def __repr__(self) -> str:
-        return (f"CompletionOutput(output={self.text!r}, "
+        return (f"CompletionOutput(index={self.index}, "
+                f"text={self.text!r}, "
                 f"token_ids={self.token_ids}, "
-                f"cumulative_logprobs={self.cumulative_logprobs}, "
+                f"cumulative_logprob={self.cumulative_logprob}, "
                 f"logprobs={self.logprobs})")

@@ -43,31 +44,32 @@ class RequestOutput:
         self.done = done

     @staticmethod
-    def from_seq_group(
-        seq_group: SequenceGroup,
-        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-    ) -> "RequestOutput":
-        outputs: List[CompletionOutput] = []
+    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+        # Get the top-n sequences.
+        n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        for seq in seqs:
-            output_token_ids = seq.data.output_token_ids
-            output_str = tokenizer.decode(output_token_ids,
-                                          skip_special_tokens=True)
-            seq_logprobs = seq.data.cumulative_logprobs
+        assert n <= len(seqs)
+        sorted_seqs = sorted(
+            seqs, key=lambda seq: seq.get_cumulative_logprob(), reverse=True)
+        top_n_seqs = sorted_seqs[:n]
+        # Create the outputs.
+        outputs: List[CompletionOutput] = []
+        for seq in top_n_seqs:
             logprobs = seq.output_logprobs
             if seq_group.sampling_params.logprobs == 0:
                 # NOTE: We need to take care of this case because the sequence
                 # always has the logprobs of the sampled tokens even if the
                 # logprobs are not requested.
                 logprobs = {}
-            output = CompletionOutput(output_str, output_token_ids,
-                                      seq_logprobs, logprobs)
+            output = CompletionOutput(seqs.index(seq), seq.output_text,
+                                      seq.get_output_token_ids(),
+                                      seq.get_cumulative_logprob(), logprobs)
             outputs.append(output)

         # Every sequence in the sequence group should have the same prompt.
-        prompt = seqs[0].prompt
-        prompt_token_ids = seqs[0].data.prompt_token_ids
+        prompt = top_n_seqs[0].prompt
+        prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
         return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
                              outputs, seq_group.is_finished())
diff --git a/cacheflow/sampling_params.py b/cacheflow/sampling_params.py
index 219ba438..0ce772a9 100644
--- a/cacheflow/sampling_params.py
+++ b/cacheflow/sampling_params.py
@@ -1,5 +1,5 @@
 """Sampling parameters for text generation."""
-from typing import Set
+from typing import List, Optional, Union


 class SamplingParams:
@@ -10,8 +10,12 @@ class SamplingParams:
     In addition, we support beam search, which is not supported by OpenAI.

     Args:
-        n: Number of output sequences to generate from the given prompt. This is
-            regarded as the beam width when using beam search.
+        n: Number of output sequences to return for the given prompt.
+        best_of: Number of output sequences that are generated from the prompt.
+            From these `best_of` sequences, the top `n` sequences are returned.
+            `best_of` must be greater than or equal to `n`. This is treated as
+            the beam width when `use_beam_search` is True. By default, `best_of`
+            is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -28,7 +32,10 @@ class SamplingParams:
         top_k: Integer that controls the number of top tokens to consider. Set
             to -1 to consider all tokens.
         use_beam_search: Whether to use beam search instead of sampling.
-        stop_token_ids: Set of token IDs that indicate the end of a sequence.
+        stop: List of strings that stop the generation when they are generated.
+            The returned output will not contain the stop strings.
+        ignore_eos: Whether to ignore the EOS token and continue generating
+            tokens after the EOS token is generated.
         max_tokens: Maximum number of tokens to generate per output sequence.
         logprobs: Number of log probabilities to return per output token.
""" @@ -36,24 +43,28 @@ class SamplingParams: def __init__( self, n: int = 1, + best_of: Optional[int] = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, use_beam_search: bool = False, - stop_token_ids: Set[int] = set(), + stop: Union[str, List[str]] = [], + ignore_eos: bool = False, max_tokens: int = 16, logprobs: int = 0, ) -> None: self.n = n + self.best_of = best_of if best_of is not None else n self.presence_penalty = presence_penalty self.frequency_penalty = frequency_penalty self.temperature = temperature self.top_p = top_p self.top_k = top_k self.use_beam_search = use_beam_search - self.stop_token_ids = stop_token_ids + self.stop = [stop] if isinstance(stop, str) else list(stop) + self.ignore_eos = ignore_eos self.max_tokens = max_tokens self.logprobs = logprobs @@ -67,6 +78,9 @@ class SamplingParams: def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") + if self.best_of < self.n: + raise ValueError(f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}.") if not -2.0 <= self.presence_penalty <= 2.0: raise ValueError("presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}.") @@ -89,8 +103,9 @@ class SamplingParams: f"logprobs must be non-negative, got {self.logprobs}.") def _verity_beam_search(self) -> None: - if self.n == 1: - raise ValueError("n must be greater than 1 when using beam search.") + if self.best_of == 1: + raise ValueError("best_of must be greater than 1 when using beam " + f"search. Got {self.best_of}.") if self.temperature > 0.0: raise ValueError("temperature must be 0 when using beam search.") if self.top_p < 1.0: @@ -99,8 +114,9 @@ class SamplingParams: raise ValueError("top_k must be -1 when using beam search.") def _verify_greedy_sampling(self) -> None: - if self.n > 1: - raise ValueError("n must be 1 when using greedy sampling.") + if self.best_of > 1: + raise ValueError("best_of must be 1 when using greedy sampling." 
+ f"Got {self.best_of}.") if self.top_p < 1.0: raise ValueError("top_p must be 1 when using greedy sampling.") if self.top_k != -1: @@ -108,12 +124,14 @@ class SamplingParams: def __repr__(self) -> str: return (f"SamplingParams(n={self.n}, " + f"best_of={self.best_of}, " f"presence_penalty={self.presence_penalty}, " f"frequency_penalty={self.frequency_penalty}, " f"temperature={self.temperature}, " f"top_p={self.top_p}, " f"top_k={self.top_k}," f"use_beam_search={self.use_beam_search}, " - f"stop_token_ids={self.stop_token_ids}, " + f"stop={self.stop}, " + f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " f"logprobs={self.logprobs})") diff --git a/cacheflow/sequence.py b/cacheflow/sequence.py index f2c0e11c..b2c19fae 100644 --- a/cacheflow/sequence.py +++ b/cacheflow/sequence.py @@ -22,11 +22,18 @@ class SequenceData: self.prompt_token_ids = prompt_token_ids self.output_token_ids: List[int] = [] - self.cumulative_logprobs = 0.0 + self.cumulative_logprob = 0.0 + + def append_token(self, token_id: int, logprob: float) -> None: + self.output_token_ids.append(token_id) + self.cumulative_logprob += logprob def get_len(self) -> int: return len(self.output_token_ids) + len(self.prompt_token_ids) + def get_output_len(self) -> int: + return len(self.output_token_ids) + def get_token_ids(self) -> List[int]: return self.prompt_token_ids + self.output_token_ids @@ -37,9 +44,9 @@ class SequenceData: def __repr__(self) -> str: return (f"SequenceData(" - f"prompt={self.prompt}, " f"prompt_token_ids={self.prompt_token_ids}, " - f"output_token_ids={self.output_token_ids})") + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob})") class Sequence: @@ -57,6 +64,7 @@ class Sequence: self.data = SequenceData(prompt_token_ids) self.output_logprobs: List[Dict[int, float]] = [] + self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. 
@@ -88,18 +96,26 @@ class Sequence:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.output_token_ids.append(token_id)
-        self.data.cumulative_logprobs += logprobs[token_id]
+        self.data.append_token(token_id, logprobs[token_id])

     def get_len(self) -> int:
         return self.data.get_len()

+    def get_output_len(self) -> int:
+        return self.data.get_output_len()
+
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

+    def get_output_token_ids(self) -> List[int]:
+        return self.data.output_token_ids
+
+    def get_cumulative_logprob(self) -> float:
+        return self.data.cumulative_logprob
+
     def fork(self, child_seq: 'Sequence') -> 'Sequence':
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index d5c5d918..4cc4a228 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -13,7 +13,7 @@ from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.tokenizer_utils import get_tokenizer
-from cacheflow.sequence import Sequence, SequenceGroup
+from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
 from cacheflow.worker.worker import Worker

@@ -49,7 +49,6 @@ class LLMServer:
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
-        self._verify_args()

         self.tokenizer = get_tokenizer(model_config.model)

@@ -124,15 +123,11 @@ class LLMServer:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seqs: List[Sequence] = []
-        for _ in range(sampling_params.n):
+        for _ in range(sampling_params.best_of):
             seq_id = next(self.seq_counter)
             seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
             seqs.append(seq)

-        # FIXME(woosuk)
-        # Add the EOS token to the stop token list.
-        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
-
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                   arrival_time)
@@ -157,18 +152,65 @@ class LLMServer:
             blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
-        # Update the scheduler.
-        updated_seq_groups = self.scheduler.update(output)
+        # Update the scheduler with the model outputs.
+        seq_groups = self.scheduler.update(output)
+
+        # Decode the sequences.
+        self._decode_sequences(seq_groups)
+        # Stop the sequences that meet the stopping criteria.
+        self._stop_sequences(seq_groups)
+        # Free the finished sequence groups.
+        self.scheduler.free_finished_seq_groups()

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in updated_seq_groups:
-            # TODO(woosuk): Batch-decode the outputs for speedup.
-            request_output = RequestOutput.from_seq_group(seq_group,
-                                                          self.tokenizer)
+        for seq_group in seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         return request_outputs

+    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Batch-decode the sequence outputs.
+        seqs: List[Sequence] = []
+        for seq_group in seq_groups:
+            seqs.extend(seq_group.get_seqs(status=SequenceStatus.RUNNING))
+        output_tokens_per_seq = []
+        for seq in seqs:
+            output_tokens_per_seq.append(seq.get_output_token_ids())
+        output_texts = self.tokenizer.batch_decode(output_tokens_per_seq,
+                                                   skip_special_tokens=True)
+        # Update the sequences with the output texts.
+        for seq, output_text in zip(seqs, output_texts):
+            seq.output_text = output_text
+
+    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+        # Stop the sequences.
+        for seq_group in seq_groups:
+            sampling_params = seq_group.sampling_params
+            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
+                # Check if the sequence has generated a stop string.
+                stopped = False
+                for stop_str in sampling_params.stop:
+                    if seq.output_text.endswith(stop_str):
+                        # Truncate the output text so that the stop string is
+                        # not included in the output.
+                        seq.output_text = seq.output_text[:-len(stop_str)]
+                        self.scheduler.free_seq(seq)
+                        stopped = True
+                        break
+                if stopped:
+                    continue
+
+                # Check if the sequence has reached max_tokens.
+                if seq.get_output_len() == sampling_params.max_tokens:
+                    self.scheduler.free_seq(seq)
+                    continue
+                # Check if the sequence has generated the EOS token.
+                if not sampling_params.ignore_eos:
+                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
+                        self.scheduler.free_seq(seq)
+                        continue
+
     def _run_workers(
         self,
         method: str,
diff --git a/examples/simple_server.py b/examples/simple_server.py
index 62714855..ace2980e 100644
--- a/examples/simple_server.py
+++ b/examples/simple_server.py
@@ -15,9 +15,9 @@ def main(args: argparse.Namespace):
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
-         SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+         SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
         ("It is only with the heart that one can see rightly",
-         SamplingParams(n=3, use_beam_search=True, temperature=0.0)),
+         SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
     ]

     # Run the server.
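For readers skimming the patch, the sketch below pulls the new sampling interface together: `best_of` controls how many sequences are generated per prompt, `n` controls how many are returned (ranked by cumulative logprob in `RequestOutput.from_seq_group`), and `stop`/`ignore_eos` replace the old `stop_token_ids` set. The `SamplingParams` fields match this diff; the `server` object and its `add_request()`, `step()`, and `has_unfinished_requests()` calls are assumed names for an LLMServer-style driver and are only illustrative, not part of this change.

from cacheflow.sampling_params import SamplingParams

# Generate five candidates but return only the best two, ranked by
# cumulative logprob. Stop strings are matched against the decoded text
# and truncated from the returned output; EOS still ends a sequence
# unless ignore_eos=True.
sampling_params = SamplingParams(
    n=2,
    best_of=5,
    temperature=0.8,
    top_p=0.95,
    stop=["\n\n"],
    ignore_eos=False,
    max_tokens=128,
)

server = ...  # an LLMServer-style object; construction is omitted here.

# Hypothetical driver loop: the method names mirror the server code touched
# above but are assumptions, not an API guarantee.
server.add_request("request-0", "What is the meaning of life?", sampling_params)
while server.has_unfinished_requests():
    for request_output in server.step():
        if not request_output.done:
            continue
        for completion in request_output.outputs:
            # CompletionOutput now exposes index, text, token_ids,
            # cumulative_logprob, and logprobs.
            print(completion.index, completion.cumulative_logprob,
                  completion.text)

Since `best_of` defaults to `n`, existing callers that only pass `n` keep their previous behavior.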