diff --git a/cacheflow/__init__.py b/cacheflow/__init__.py
index 113cd26d..03ac4248 100644
--- a/cacheflow/__init__.py
+++ b/cacheflow/__init__.py
@@ -1,19 +1,15 @@
+from cacheflow.entrypoints.llm import LLM
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments,
-    create_server_configs_from_args,
-    initialize_server_from_args,
-)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
 
 __all__ = [
-    "RequestOutput",
+    "LLM",
     "SamplingParams",
+    "RequestOutput",
     "LLMServer",
-    "add_server_arguments",
-    "create_server_configs_from_args",
-    "initialize_server_from_args",
+    "ServerArgs",
     "initialize_cluster",
 ]
diff --git a/cacheflow/config.py b/cacheflow/config.py
index dda48359..93992b96 100644
--- a/cacheflow/config.py
+++ b/cacheflow/config.py
@@ -3,6 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig
 
+_GiB = 1 << 30
+
 
 class ModelConfig:
@@ -70,7 +72,7 @@ class CacheConfig:
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space = swap_space
+        self.swap_space_bytes = swap_space * _GiB
 
         # Will be set after profiling.
         self.num_gpu_blocks = None
@@ -138,6 +140,8 @@ def _get_and_verify_dtype(
         else:
             torch_dtype = config_dtype
     else:
+        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+            raise ValueError(f"Unknown dtype: {dtype}")
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
 
     # Verify the dtype.
diff --git a/cacheflow/entrypoints/fastapi_server.py b/cacheflow/entrypoints/fastapi_server.py
index 26882f10..f69b82bc 100644
--- a/cacheflow/entrypoints/fastapi_server.py
+++ b/cacheflow/entrypoints/fastapi_server.py
@@ -12,8 +12,7 @@ import uvicorn
 
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import (
-    add_server_arguments, create_server_configs_from_args)
+from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
 from cacheflow.server.ray_utils import initialize_cluster
 
@@ -116,10 +115,10 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=10002)
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
 
-    server_configs = create_server_configs_from_args(args)
+    server_configs = ServerArgs.from_cli_args(args).create_server_configs()
     parallel_config = server_configs[2]
     distributed_init_method, stage_devices = initialize_cluster(parallel_config)
 
diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py
new file mode 100644
index 00000000..821b7db9
--- /dev/null
+++ b/cacheflow/entrypoints/llm.py
@@ -0,0 +1,62 @@
+from typing import List, Optional
+
+from tqdm import tqdm
+
+from cacheflow.outputs import RequestOutput
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.llm_server import LLMServer
+from cacheflow.utils import Counter
+
+
+class LLM:
+
+    def __init__(
+        self,
+        model: str,
+        tensor_parallel_size: int = 1,
+        dtype: str = "default",
+        seed: int = 0,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        server_args = ServerArgs(
+            model=model,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            seed=seed,
+            **kwargs,
+        )
+        self.llm_server = LLMServer.from_server_args(server_args)
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: List[str],
+        sampling_params: Optional[SamplingParams] = None,
+        use_tqdm: bool = True,
+    ) -> List[RequestOutput]:
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        # Initialize tqdm.
+        if use_tqdm:
+            pbar = tqdm(total=len(prompts), desc="Processed prompts")
+
+        # Add requests to the server.
+        for prompt in prompts:
+            request_id = str(next(self.request_counter))
+            self.llm_server.add_request(request_id, prompt, sampling_params)
+
+        # Run the server.
+        outputs: List[RequestOutput] = []
+        while self.llm_server.has_unfinished_requests():
+            step_outputs = self.llm_server.step()
+            for output in step_outputs:
+                if output.done:
+                    outputs.append(output)
+                    if use_tqdm:
+                        pbar.update(1)
+        if use_tqdm:
+            pbar.close()
+        return outputs
diff --git a/cacheflow/outputs.py b/cacheflow/outputs.py
index 0b4dcabe..84fd7197 100644
--- a/cacheflow/outputs.py
+++ b/cacheflow/outputs.py
@@ -35,7 +35,7 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
-        done: bool = False,
+        done: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
@@ -43,8 +43,8 @@ class RequestOutput:
         self.outputs = outputs
         self.done = done
 
-    @staticmethod
-    def from_seq_group(seq_group: SequenceGroup) -> "RequestOutput":
+    @classmethod
+    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
@@ -70,8 +70,8 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return RequestOutput(seq_group.request_id, prompt, prompt_token_ids,
-                             outputs, seq_group.is_finished())
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   seq_group.is_finished())
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
diff --git a/cacheflow/server/arg_utils.py b/cacheflow/server/arg_utils.py
index ce429ada..a4b898dd 100644
--- a/cacheflow/server/arg_utils.py
+++ b/cacheflow/server/arg_utils.py
@@ -1,74 +1,117 @@
 import argparse
-from typing import Tuple
+import dataclasses
+from dataclasses import dataclass
+from typing import Optional, Tuple
 
 from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                               SchedulerConfig)
-from cacheflow.server.llm_server import LLMServer
-from cacheflow.server.ray_utils import initialize_cluster
-
-_GiB = 1 << 30
 
 
-def add_server_arguments(parser: argparse.ArgumentParser):
-    """Shared arguments for CacheFlow servers."""
+@dataclass
+class ServerArgs:
+    model: str
+    download_dir: Optional[str] = None
+    use_np_weights: bool = False
+    use_dummy_weights: bool = False
+    dtype: str = "default"
+    seed: int = 0
+    use_ray: bool = False
+    pipeline_parallel_size: int = 1
+    tensor_parallel_size: int = 1
+    block_size: int = 16
+    swap_space: int = 4  # GiB
+    gpu_memory_utilization: float = 0.95
+    max_num_batched_tokens: int = 2560
+    max_num_seqs: int = 256
+    disable_log_stats: bool = False
+
+    def __post_init__(self):
+        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
+
+    @staticmethod
+    def add_cli_args(
+        parser: argparse.ArgumentParser,
+    ) -> argparse.ArgumentParser:
+        return _add_server_arguments(parser)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+        # Get the list of attributes of this dataclass.
+        attrs = [attr.name for attr in dataclasses.fields(cls)]
+        # Set the attributes from the parsed arguments.
+        server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+        return server_args
+
+    def create_server_configs(
+        self,
+    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
+        # Initialize the configs.
+        model_config = ModelConfig(
+            self.model, self.download_dir, self.use_np_weights,
+            self.use_dummy_weights, self.dtype, self.seed)
+        cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization,
+                                   self.swap_space)
+        parallel_config = ParallelConfig(self.pipeline_parallel_size,
+                                         self.tensor_parallel_size,
+                                         self.use_ray)
+        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
+                                           self.max_num_seqs)
+        return model_config, cache_config, parallel_config, scheduler_config
+
+
+def _add_server_arguments(
+    parser: argparse.ArgumentParser,
+) -> argparse.ArgumentParser:
+    """Shared CLI arguments for CacheFlow servers."""
     # Model arguments
-    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
-    parser.add_argument('--download-dir', type=str, default=None,
+    parser.add_argument('--model', type=str, default='facebook/opt-125m',
+                        help='name or path of the huggingface model to use')
+    parser.add_argument('--download-dir', type=str,
+                        default=ServerArgs.download_dir,
                         help='directory to download and load the weights, '
                              'default to the default cache dir of huggingface')
     parser.add_argument('--use-np-weights', action='store_true',
-                        help='save a numpy copy of model weights for faster loading')
-    parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
+                        help='save a numpy copy of model weights for faster '
+                             'loading. This can increase the disk usage by up '
+                             'to 2x.')
+    parser.add_argument('--use-dummy-weights', action='store_true',
+                        help='use dummy values for model weights')
     # TODO(woosuk): Support FP32.
-    parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'],
+    parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+                        choices=['default', 'half', 'bfloat16'],
                         help=('data type for model weights and activations. '
                               'The "default" option will use FP16 precision '
                               'for FP32 and FP16 models, and BF16 precision '
                               'for BF16 models.'))
     # Parallel arguments
-    parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU')
-    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--use-ray', action='store_true',
+                        help='use Ray for distributed serving, will be '
+                             'automatically set when using more than 1 GPU')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
+                        default=ServerArgs.pipeline_parallel_size,
+                        help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int,
+                        default=ServerArgs.tensor_parallel_size,
+                        help='number of tensor parallel replicas')
     # KV cache arguments
-    parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size')
+    parser.add_argument('--block-size', type=int, default=ServerArgs.block_size,
+                        choices=[1, 2, 4, 8, 16, 32, 64, 128, 256],
+                        help='token block size')
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-    parser.add_argument('--seed', type=int, default=0, help='random seed')
-    parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor')
-    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration')
-    parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration')
-    parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics')
+    parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+                        help='random seed')
+    parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space,
+                        help='CPU swap space size (GiB) per GPU')
+    parser.add_argument('--gpu-memory-utilization', type=float,
+                        default=ServerArgs.gpu_memory_utilization,
+                        help='the percentage of GPU memory to be used for the '
+                             'model executor')
+    parser.add_argument('--max-num-batched-tokens', type=int,
+                        default=ServerArgs.max_num_batched_tokens,
+                        help='maximum number of batched tokens per iteration')
+    parser.add_argument('--max-num-seqs', type=int,
+                        default=ServerArgs.max_num_seqs,
+                        help='maximum number of sequences per iteration')
+    parser.add_argument('--disable-log-stats', action='store_true',
+                        help='disable logging statistics')
     return parser
-
-
-def create_server_configs_from_args(
-    args: argparse.Namespace,
-) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
-    # Post-process the parsed arguments.
-    args.swap_space = args.swap_space * _GiB
-    args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens)
-
-    # Initialize the configs.
-    model_config = ModelConfig(
-        args.model, args.download_dir, args.use_np_weights,
-        args.use_dummy_weights, args.dtype, args.seed)
-    cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization,
-                               args.swap_space)
-    parallel_config = ParallelConfig(args.pipeline_parallel_size,
-                                     args.tensor_parallel_size, args.use_ray)
-    scheduler_config = SchedulerConfig(args.max_num_batched_tokens,
-                                       args.max_num_seqs)
-    return model_config, cache_config, parallel_config, scheduler_config
-
-
-def initialize_server_from_args(args: argparse.Namespace) -> LLMServer:
-    server_configs = create_server_configs_from_args(args)
-    parallel_config = server_configs[2]
-
-    # Initialize the cluster.
-    distributed_init_method, devices = initialize_cluster(parallel_config)
-
-    # Create the LLM server.
-    server = LLMServer(*server_configs, distributed_init_method, devices,
-                       log_stats=not args.disable_log_stats)
-    return server
diff --git a/cacheflow/server/llm_server.py b/cacheflow/server/llm_server.py
index 4cc4a228..5e01bc69 100644
--- a/cacheflow/server/llm_server.py
+++ b/cacheflow/server/llm_server.py
@@ -12,6 +12,8 @@ from cacheflow.core.scheduler import Scheduler
 from cacheflow.logger import init_logger
 from cacheflow.outputs import RequestOutput
 from cacheflow.sampling_params import SamplingParams
+from cacheflow.server.arg_utils import ServerArgs
+from cacheflow.server.ray_utils import initialize_cluster
 from cacheflow.server.tokenizer_utils import get_tokenizer
 from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
 from cacheflow.utils import Counter
@@ -30,7 +32,7 @@ class LLMServer:
         scheduler_config: SchedulerConfig,
         distributed_init_method: str,
         stage_devices: List[List[Any]],
-        log_stats: bool = True,
+        log_stats: bool,
     ) -> None:
         logger.info(
             "Initializing an LLM server with config: "
@@ -90,7 +92,7 @@ class LLMServer:
             get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
-            cpu_swap_space=self.cache_config.swap_space,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
         )
 
         # Since we use a shared centralized controller, we take the minimum
@@ -107,6 +109,18 @@ class LLMServer:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
 
+    @classmethod
+    def from_server_args(cls, server_args: ServerArgs) -> "LLMServer":
+        # Create the server configs.
+        server_configs = server_args.create_server_configs()
+        parallel_config = server_configs[2]
+        # Initialize the cluster.
+        distributed_init_method, devices = initialize_cluster(parallel_config)
+        # Create the LLM server.
+        server = cls(*server_configs, distributed_init_method, devices,
+                     log_stats=not server_args.disable_log_stats)
+        return server
+
     def add_request(
         self,
         request_id: str,
diff --git a/examples/offline_inference.py b/examples/offline_inference.py
new file mode 100644
index 00000000..35fadfd6
--- /dev/null
+++ b/examples/offline_inference.py
@@ -0,0 +1,23 @@
+from cacheflow import LLM, SamplingParams
+
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+# Create an LLM.
+llm = LLM(model="facebook/opt-125m")
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
diff --git a/examples/simple_server.py b/examples/simple_server.py
index ace2980e..781c05f7 100644
--- a/examples/simple_server.py
+++ b/examples/simple_server.py
@@ -1,13 +1,13 @@
 import argparse
 import uuid
 
-from cacheflow import (add_server_arguments, initialize_server_from_args,
-                       SamplingParams)
+from cacheflow import ServerArgs, LLMServer, SamplingParams
 
 
 def main(args: argparse.Namespace):
-    # Initialize the server.
-    server = initialize_server_from_args(args)
+    # Parse the CLI argument and initialize the server.
+    server_args = ServerArgs.from_cli_args(args)
+    server = LLMServer.from_server_args(server_args)
 
     # Test the following prompts.
     test_prompts = [
@@ -39,6 +39,6 @@ def main(args: argparse.Namespace):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Simple CacheFlow server.')
-    parser = add_server_arguments(parser)
+    parser = ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
     main(args)