From 5bd3c650721cc5de451f034bcbed37d1a1a4116c Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Tue, 28 May 2024 22:13:52 -0700
Subject: [PATCH] [Core][Optimization] remove vllm-nccl (#5091)

---
 .buildkite/test-pipeline.yaml              |  1 -
 requirements-cuda.txt                      |  1 -
 setup.py                                   |  7 +--
 tests/distributed/test_pynccl_library.py   | 43 -------------------
 .../device_communicators/pynccl_wrapper.py | 20 +++------
 vllm/utils.py                              | 43 ++++---------------
 vllm/worker/worker_base.py                 |  6 ++-
 7 files changed, 21 insertions(+), 100 deletions(-)
 delete mode 100644 tests/distributed/test_pynccl_library.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 08e132d0..21cbd9ba 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -37,7 +37,6 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
-  - pytest -v -s distributed/test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index acb01640..5109f173 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -4,7 +4,6 @@
 # Dependencies for NVIDIA GPUs
 ray >= 2.9
 nvidia-ml-py # for pynvml package
-vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
 vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0
diff --git a/setup.py b/setup.py
index a66af2c5..b4baebb0 100644
--- a/setup.py
+++ b/setup.py
@@ -358,11 +358,8 @@ def get_requirements() -> List[str]:
         cuda_major, cuda_minor = torch.version.cuda.split(".")
         modified_requirements = []
         for req in requirements:
-            if "vllm-nccl-cu12" in req:
-                req = req.replace("vllm-nccl-cu12",
-                                  f"vllm-nccl-cu{cuda_major}")
-            elif ("vllm-flash-attn" in req
-                  and not (cuda_major == "12" and cuda_minor == "1")):
+            if ("vllm-flash-attn" in req
+                    and not (cuda_major == "12" and cuda_minor == "1")):
                 # vllm-flash-attn is built only for CUDA 12.1.
                 # Skip for other versions.
                 continue
diff --git a/tests/distributed/test_pynccl_library.py b/tests/distributed/test_pynccl_library.py
deleted file mode 100644
index ec60a5ed..00000000
--- a/tests/distributed/test_pynccl_library.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import multiprocessing
-import tempfile
-
-
-def target_fn(env, filepath):
-    from vllm.utils import update_environment_variables
-    update_environment_variables(env)
-    from vllm.utils import nccl_integrity_check
-    nccl_integrity_check(filepath)
-
-
-def test_library_file():
-    # note: don't import vllm.distributed.device_communicators.pynccl
-    # before running this test, otherwise the library file will be loaded
-    # and it might interfere with the test
-    from vllm.utils import find_nccl_library
-    so_file = find_nccl_library()
-    with open(so_file, 'rb') as f:
-        content = f.read()
-    try:
-        # corrupt the library file, should raise an exception
-        with open(so_file, 'wb') as f:
-            f.write(content[:len(content) // 2])
-        p = multiprocessing.Process(target=target_fn, args=({}, so_file))
-        p.start()
-        p.join()
-        assert p.exitcode != 0
-
-        # move the library file to a tmp path
-        # test VLLM_NCCL_SO_PATH
-        fd, path = tempfile.mkstemp()
-        with open(path, 'wb') as f:
-            f.write(content)
-        p = multiprocessing.Process(target=target_fn,
-                                    args=({
-                                        "VLLM_NCCL_SO_PATH": path
-                                    }, path))
-        p.start()
-        p.join()
-        assert p.exitcode == 0
-    finally:
-        with open(so_file, 'wb') as f:
-            f.write(content)
diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py
index 3aa3744d..50d6719f 100644
--- a/vllm/distributed/device_communicators/pynccl_wrapper.py
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -28,7 +28,7 @@ import torch
 from torch.distributed import ReduceOp
 
 from vllm.logger import init_logger
-from vllm.utils import find_nccl_library, nccl_integrity_check
+from vllm.utils import find_nccl_library
 
 logger = init_logger(__name__)
 
@@ -188,28 +188,22 @@ class NCCLLibrary:
         so_file = so_file or find_nccl_library()
 
         try:
-            # load the library in another process.
-            # if it core dumps, it will not crash the current process
-            nccl_integrity_check(so_file)
+            if so_file not in NCCLLibrary.path_to_dict_mapping:
+                lib = ctypes.CDLL(so_file)
+                NCCLLibrary.path_to_library_cache[so_file] = lib
+            self.lib = NCCLLibrary.path_to_library_cache[so_file]
         except Exception as e:
             logger.error(
                 "Failed to load NCCL library from %s ."
                 "It is expected if you are not running on NVIDIA/AMD GPUs."
                 "Otherwise, the nccl library might not exist, be corrupted "
                 "or it does not support the current platform %s."
-                "One solution is to download libnccl2 version 2.18 from "
-                "https://developer.download.nvidia.com/compute/cuda/repos/ "
-                "and extract the libnccl.so.2 file. If you already have the "
-                "library, please set the environment variable VLLM_NCCL_SO_PATH"
+                "If you already have the library, please set the "
+                "environment variable VLLM_NCCL_SO_PATH"
                 " to point to the correct nccl library path.", so_file,
                 platform.platform())
             raise e
 
-        if so_file not in NCCLLibrary.path_to_dict_mapping:
-            lib = ctypes.CDLL(so_file)
-            NCCLLibrary.path_to_library_cache[so_file] = lib
-        self.lib = NCCLLibrary.path_to_library_cache[so_file]
-
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs = {}
             for func in NCCLLibrary.exported_functions:
diff --git a/vllm/utils.py b/vllm/utils.py
index c8bc54da..85e045cb 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -2,7 +2,6 @@ import asyncio
 import datetime
 import enum
 import gc
-import glob
 import os
 import socket
 import subprocess
@@ -565,28 +564,6 @@ def init_cached_hf_modules():
     init_hf_modules()
 
 
-def nccl_integrity_check(filepath):
-    """
-    when the library is corrupted, we cannot catch
-    the exception in python. it will crash the process.
-    instead, we use the exit code of `ldd` to check
-    if the library is corrupted. if not, we will return
-    the version of the library.
-    """
-    exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
-    if exit_code != 0:
-        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
-    import ctypes
-
-    nccl = ctypes.CDLL(filepath)
-    version = ctypes.c_int()
-    nccl.ncclGetVersion.restype = ctypes.c_int
-    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-    result = nccl.ncclGetVersion(ctypes.byref(version))
-    assert result == 0
-    return version.value
-
-
 @lru_cache(maxsize=None)
 def find_library(lib_name: str) -> str:
     """
@@ -616,17 +593,13 @@ def find_library(lib_name: str) -> str:
 
 
 def find_nccl_library():
+    """
+    We either use the library file specified by the `VLLM_NCCL_SO_PATH`
+    environment variable, or we find the library file brought by PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
     so_file = envs.VLLM_NCCL_SO_PATH
-    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
-
-    # check if we have vllm-managed nccl
-    vllm_nccl_path = None
-    if torch.version.cuda is not None:
-        cuda_major = torch.version.cuda.split(".")[0]
-        path = os.path.expanduser(
-            f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
-        files = glob.glob(path)
-        vllm_nccl_path = files[0] if files else None
 
     # manually load the nccl library
     if so_file:
@@ -635,9 +608,9 @@ def find_nccl_library():
                     so_file)
     else:
         if torch.version.cuda is not None:
-            so_file = vllm_nccl_path or find_library("libnccl.so.2")
+            so_file = "libnccl.so.2"
        elif torch.version.hip is not None:
-            so_file = find_library("librccl.so.1")
+            so_file = "librccl.so.1"
         else:
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
     logger.info("Found nccl from library %s", so_file)
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index dbac1b5b..258f31de 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -121,12 +121,14 @@ class WorkerWrapperBase:
 
     def init_worker(self, *args, **kwargs):
         """
-        Actual initialization of the worker class, and set up
-        function tracing if required.
+        Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
""" enable_trace_function_call_for_thread() + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + mod = importlib.import_module(self.worker_module_name) worker_class = getattr(mod, self.worker_class_name) self.worker = worker_class(*args, **kwargs)