From 3f92038b990b26fdb9e6a9bccab0e3ec0cdc6aea Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 18 Jun 2023 11:39:35 -0700 Subject: [PATCH] Add comments on swap space (#154) --- benchmarks/benchmark_serving.py | 3 ++- vllm/core/scheduler.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 91819f86..b8d824c7 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -3,7 +3,8 @@ On the server side, run one of the following commands: (vLLM backend) python -m vllm.entrypoints.api_server \ - --disable-log-requests --model + --model --swap-space 16 \ + --disable-log-requests (TGI backend) ./launch_hf_server.sh diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 8081c844..500c5ddd 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -409,7 +409,12 @@ class Scheduler: seq_group: SequenceGroup, blocks_to_swap_out: Dict[int, int], ) -> None: - assert self.block_manager.can_swap_out(seq_group) + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):