diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index bbadc0e0..460d9907 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -506,7 +506,9 @@ class ModelRunner:
                     "use '--enforce-eager' in the CLI.")
         logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
                     "If you are running out of memory, consider decreasing "
-                    "`gpu_memory_utilization` or enforcing eager mode.")
+                    "`gpu_memory_utilization` or enforcing eager mode. "
+                    "You can also reduce the `max_num_seqs` as needed "
+                    "to decrease memory usage.")
         start_time = time.perf_counter()
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
@@ -519,9 +521,15 @@ class ModelRunner:
         context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
 
+        graph_batch_size = _get_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
+        batch_size_capture_list = [
+            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
+        ]
+
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
-        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
+        for batch_size in reversed(batch_size_capture_list):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
                 is_prompt=False,
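
Note on the change: `batch_size_capture_list` restricts CUDA graph capture to batch sizes the scheduler can actually produce, instead of the full `_BATCH_SIZES_TO_CAPTURE` list. Below is a minimal standalone sketch of that logic; the candidate list and the padding helper are assumptions chosen for illustration (representative of, but not necessarily identical to, the definitions in `vllm/worker/model_runner.py`).

```python
# Standalone sketch of the capture-list filtering shown in the diff.
# _BATCH_SIZE_ALIGNMENT, _BATCH_SIZES_TO_CAPTURE, and _get_graph_batch_size
# are assumed stand-ins here, not the repository's exact definitions.

_BATCH_SIZE_ALIGNMENT = 8  # assumed alignment for padded graph batch sizes
# Assumed capture candidates: 1, 2, 4, then multiples of 8 up to 256.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


def _get_graph_batch_size(batch_size: int) -> int:
    """Pad a batch size up to the nearest capture candidate (assumed behavior)."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    # Round up to the next multiple of the alignment.
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def get_capture_list(max_num_seqs: int) -> list[int]:
    """Keep only the batch sizes the scheduler can actually schedule."""
    graph_batch_size = _get_graph_batch_size(max_num_seqs)
    return [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size]


if __name__ == "__main__":
    # With max_num_seqs=20, only sizes up to 24 are captured rather than the
    # full candidate list up to 256.
    print(get_capture_list(20))   # [1, 2, 4, 8, 16, 24]
    print(get_capture_list(256))  # the full candidate list
```

Since one CUDA graph is captured per batch size in the list, dropping batch sizes above the padded `max_num_seqs` avoids capturing graphs that could never be used, which is where the memory saving comes from.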