From 593e79e7337f7fd9e92b7554dabdff96769dbf15 Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Fri, 26 Jul 2024 23:15:20 -0600
Subject: [PATCH] [Bugfix] torch.set_num_threads() in multiproc_gpu_executor
 (#6802)

[Bugfix] Use torch.set_num_threads() to configure parallelism in
multiproc_gpu_executor (#6802)

Signed-off-by: Travis Johnson
---
 vllm/executor/multiproc_gpu_executor.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 19f7a497..e1e92958 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -6,6 +6,8 @@ import weakref
 from functools import partial
 from typing import Any, List, Optional
 
+import torch
+
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.gpu_executor import create_worker
@@ -45,10 +47,23 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         # Disable torch async compiling which won't work with daemonic processes
         os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
 
-        # Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
-        # contention amongst the shards
-        if "OMP_NUM_THREADS" not in os.environ:
-            os.environ["OMP_NUM_THREADS"] = "1"
+        # Configure thread parallelism if OMP_NUM_THREADS isn't set
+        #
+        # Helps to avoid CPU contention. The default of spawning a thread per
+        # core combined with multiprocessing for each GPU can have a negative
+        # impact on performance. The contention is amplified when running in a
+        # container where CPU limits can cause throttling.
+        default_omp_num_threads = 1
+        if "OMP_NUM_THREADS" not in os.environ and (
+                current_parallelism :=
+                torch.get_num_threads()) > default_omp_num_threads:
+            logger.warning(
+                "Reducing Torch parallelism from %d threads to %d to avoid "
+                "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
+                "external environment to tune this value as needed.",
+                current_parallelism, default_omp_num_threads)
+            os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
+            torch.set_num_threads(default_omp_num_threads)
 
         # workaround for https://github.com/vllm-project/vllm/issues/6103
         if world_size > 1:
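
As a sanity check, the guard this patch adds can be exercised outside vLLM
with a minimal standalone sketch. It assumes only that torch is installed;
print() stands in for vLLM's logger.warning, and the variable names simply
mirror the diff above.

    import os

    import torch

    # Cap Torch's intra-op thread pool unless the user opted out by setting
    # OMP_NUM_THREADS themselves.
    default_omp_num_threads = 1
    if "OMP_NUM_THREADS" not in os.environ and (
            current_parallelism :=
            torch.get_num_threads()) > default_omp_num_threads:
        print(f"Reducing Torch parallelism from {current_parallelism} "
              f"threads to {default_omp_num_threads}")
        # Exporting the variable makes worker subprocesses inherit the limit;
        # set_num_threads() caps the current process, since OpenMP reads
        # OMP_NUM_THREADS only at startup.
        os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
        torch.set_num_threads(default_omp_num_threads)

    print(torch.get_num_threads())  # 1, unless OMP_NUM_THREADS was preset

Run on a multi-core machine without OMP_NUM_THREADS set, this should print
the reduction message; with, say, OMP_NUM_THREADS=8 exported, the guard is
skipped and the user's choice is left alone.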