# vllm/cacheflow/engine/llm_engine.py

import time
from typing import Any, List, Optional

from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
                              SchedulerConfig)
from cacheflow.core.scheduler import Scheduler
from cacheflow.engine.arg_utils import EngineArgs
from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
                                              get_tokenizer)
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
from cacheflow.utils import Counter
from cacheflow.worker.worker import Worker

logger = init_logger(__name__)


class LLMEngine:
"""An LLM engine that receives requests and generates texts.
2023-06-17 17:25:21 +08:00
This is the main class for the CacheFlow LLM engine. It receives requests
from clients and generates texts from the LLM. It includes a tokenizer, a
language model (possibly distributed across multiple GPUs), and GPU memory
space allocated for intermediate states (aka KV cache). This class utilizes
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
2023-06-17 17:25:21 +08:00
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
Args:
model_config: The configuration related to the LLM model.
cache_config: The configuration related to the KV cache memory
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
stage_devices: The list of devices for each stage. Each stage is a list
of (rank, node_resource, device) tuples.
log_stats: Whether to log statistics.
"""

    def __init__(
        self,
        model_config: ModelConfig,
        cache_config: CacheConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        distributed_init_method: str,
        stage_devices: List[List[DeviceID]],
        log_stats: bool,
    ) -> None:
        logger.info(
            "Initializing an LLM engine with config: "
            f"model={model_config.model!r}, "
            f"dtype={model_config.dtype}, "
            f"use_dummy_weights={model_config.use_dummy_weights}, "
            f"download_dir={model_config.download_dir!r}, "
            f"use_np_weights={model_config.use_np_weights}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"seed={model_config.seed})"
        )
        # TODO(woosuk): Print more configs in debug mode.

        self.model_config = model_config
        self.cache_config = cache_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.log_stats = log_stats
        self._verify_args()

        self.tokenizer = get_tokenizer(model_config.model)
        self.seq_counter = Counter()

        # Create the parallel GPU workers.
        self.workers: List[Worker] = []
        assert len(stage_devices) == 1, "Only support one stage for now."
        for rank, node_resource, _ in stage_devices[0]:
            worker_cls = Worker
            if self.parallel_config.worker_use_ray:
                worker_cls = ray.remote(
                    num_cpus=0,
                    num_gpus=1,
                    resources={node_resource: 1e-5},
                )(worker_cls).remote
            worker = worker_cls(
                model_config,
                parallel_config,
                scheduler_config,
                rank,
                distributed_init_method,
            )
            self.workers.append(worker)

        # Profile the memory usage and initialize the cache.
        self._init_cache()

        # Create the scheduler.
        self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)

    def _verify_args(self) -> None:
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
        """Profiles the memory usage and initializes the KV cache."""
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
            get_all_outputs=True,
            block_size=self.cache_config.block_size,
            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
            cpu_swap_space=self.cache_config.swap_space_bytes,
        )

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operations can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        num_cpu_blocks = min(b[1] for b in num_blocks)
        # FIXME(woosuk): Change to debug log.
        logger.info(f'# GPU blocks: {num_gpu_blocks}, '
                    f'# CPU blocks: {num_cpu_blocks}')
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        # Initialize the cache.
        self._run_workers("init_cache_engine", cache_config=self.cache_config)

    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
        """Creates an LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        distributed_init_method, devices = initialize_cluster(parallel_config)
        # Create the LLM engine.
        engine = cls(*engine_configs, distributed_init_method, devices,
                     log_stats=not engine_args.disable_log_stats)
        return engine
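
    # Example (illustrative sketch): constructing an engine from `EngineArgs`.
    # The exact `EngineArgs` fields (e.g. the `model` keyword and the model
    # name below) are assumptions for illustration; see `EngineArgs` for the
    # real argument list.
    #
    #   engine_args = EngineArgs(model="facebook/opt-125m")  # hypothetical args
    #   engine = LLMEngine.from_engine_args(engine_args)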

    def add_request(
        self,
        request_id: str,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
    ) -> None:
        """Add a request to the engine's request pool.

        The request is added to the request pool and will be processed by the
        scheduler as `engine.step()` is called. The exact scheduling policy is
        determined by the scheduler.

        Args:
            request_id: The unique ID of the request.
            prompt: The prompt string. Can be None if prompt_token_ids is
                provided.
            sampling_params: The sampling parameters for text generation.
            prompt_token_ids: The token IDs of the prompt. If None, we
                use the tokenizer to convert the prompt to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current time.
        """
        if arrival_time is None:
            arrival_time = time.time()
        if prompt_token_ids is None:
            assert prompt is not None
            prompt_token_ids = self.tokenizer.encode(prompt)

        # Create the sequences.
        block_size = self.cache_config.block_size
        seqs: List[Sequence] = []
        for _ in range(sampling_params.best_of):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
            seqs.append(seq)

        # Create the sequence group.
        seq_group = SequenceGroup(request_id, seqs, sampling_params,
                                  arrival_time)

        # Add the sequence group to the scheduler.
        self.scheduler.add_seq_group(seq_group)
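
    # Example (illustrative sketch): adding a request with a text prompt. The
    # `SamplingParams` keyword arguments below are assumed to mirror the
    # attributes used in this file (`best_of`, `stop`, `max_tokens`,
    # `ignore_eos`); the actual constructor lives in
    # `cacheflow.sampling_params`.
    #
    #   params = SamplingParams(best_of=1, max_tokens=128, stop=["\n"])
    #   engine.add_request(request_id="req-0",
    #                      prompt="Hello, my name is",
    #                      sampling_params=params)
    #
    # Alternatively, a pre-tokenized prompt can be passed via
    # `prompt_token_ids`, in which case `prompt` may be None.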

    def abort_request(self, request_id: str) -> None:
        """Aborts a request with the given ID.

        Args:
            request_id: The ID of the request to abort.
        """
        self.scheduler.abort_seq_group(request_id)

    def get_num_unfinished_requests(self) -> int:
        """Gets the number of unfinished requests."""
        return self.scheduler.get_num_unfinished_seq_groups()

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests."""
        return self.scheduler.has_unfinished_seqs()

    def step(self) -> List[RequestOutput]:
        """Performs one decoding iteration and returns newly generated results.

        This function performs one decoding iteration of the engine. It first
        schedules the sequences to be executed in the next iteration and the
        token blocks to be swapped in, swapped out, or copied. Then, it
        executes the model and updates the scheduler with the model outputs.
        Finally, it decodes the sequences and returns the newly generated
        results.
        """
        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
        if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
            # Nothing to do.
            return []

        # Execute the model.
        output = self._run_workers(
            "execute_model",
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
        )

        # Update the scheduler with the model outputs.
        seq_groups = self.scheduler.update(output)

        # Decode the sequences.
        self._decode_sequences(seq_groups)
        # Stop the sequences that meet the stopping criteria.
        self._stop_sequences(seq_groups)
        # Free the finished sequence groups.
        self.scheduler.free_finished_seq_groups()

        # Create the outputs.
        request_outputs: List[RequestOutput] = []
        for seq_group in seq_groups:
            request_output = RequestOutput.from_seq_group(seq_group)
            request_outputs.append(request_output)
        return request_outputs
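
    # Example (illustrative sketch): a typical generation loop built on
    # `add_request`, `has_unfinished_requests`, and `step`. Only the control
    # flow is shown; how the `RequestOutput` objects are consumed is up to the
    # caller.
    #
    #   while engine.has_unfinished_requests():
    #       request_outputs = engine.step()
    #       for request_output in request_outputs:
    #           ...  # stream or collect the newly generated results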

    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Decodes the sequence outputs."""
        for seq_group in seq_groups:
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                new_token, new_output_text = detokenize_incrementally(
                    self.tokenizer,
                    seq.output_tokens,
                    seq.get_last_token_id(),
                    skip_special_tokens=True,
                )
                seq.output_tokens.append(new_token)
                seq.output_text = new_output_text

    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
        """Stops the sequences that meet the stopping criteria."""
        for seq_group in seq_groups:
            sampling_params = seq_group.sampling_params
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                # Check if the sequence has generated a stop string.
                stopped = False
                for stop_str in sampling_params.stop:
                    if seq.output_text.endswith(stop_str):
                        # Truncate the output text so that the stop string is
                        # not included in the output.
                        seq.output_text = seq.output_text[:-len(stop_str)]
                        self.scheduler.free_seq(seq,
                                                SequenceStatus.FINISHED_STOPPED)
                        stopped = True
                        break
                if stopped:
                    continue

                # Check if the sequence has reached max_tokens.
                if seq.get_output_len() == sampling_params.max_tokens:
                    self.scheduler.free_seq(
                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
                    continue

                # Check if the sequence has generated the EOS token.
                if not sampling_params.ignore_eos:
                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
                        self.scheduler.free_seq(seq,
                                                SequenceStatus.FINISHED_STOPPED)
                        continue
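
    # Example (illustrative): for a request with the hypothetical values
    # `stop=["\n"]`, `max_tokens=16`, and `ignore_eos=False`, a sequence
    # finishes as FINISHED_STOPPED as soon as its output text ends with "\n"
    # (the stop string is trimmed from the output), as FINISHED_LENGTH_CAPPED
    # once 16 output tokens have been generated, or as FINISHED_STOPPED when
    # the tokenizer's EOS token is sampled, whichever happens first.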

    def _run_workers(
        self,
        method: str,
        get_all_outputs: bool = False,
        *args,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers."""
        all_outputs = []
        for worker in self.workers:
            executor = getattr(worker, method)
            if self.parallel_config.worker_use_ray:
                executor = executor.remote

            output = executor(*args, **kwargs)
            all_outputs.append(output)

        if self.parallel_config.worker_use_ray:
            all_outputs = ray.get(all_outputs)

        if get_all_outputs:
            return all_outputs

        # Make sure all workers have the same results.
        output = all_outputs[0]
        for other_output in all_outputs[1:]:
            assert output == other_output
        return output