[Core][Optimization] remove vllm-nccl (#5091)

youkaichao 2024-05-28 22:13:52 -07:00 committed by GitHub
parent 616e600e0b
commit 5bd3c65072
7 changed files with 21 additions and 100 deletions

View File: .buildkite/test-pipeline.yaml

@@ -37,7 +37,6 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
-  - pytest -v -s distributed/test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py

View File: requirements-cuda.txt

@@ -4,7 +4,6 @@
 # Dependencies for NVIDIA GPUs
 ray >= 2.9
 nvidia-ml-py # for pynvml package
-vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
 torch == 2.3.0
 xformers == 0.0.26.post1  # Requires PyTorch 2.3.0
 vllm-flash-attn == 2.5.8.post2  # Requires PyTorch 2.3.0

View File: setup.py

@@ -358,10 +358,7 @@ def get_requirements() -> List[str]:
         cuda_major, cuda_minor = torch.version.cuda.split(".")
         modified_requirements = []
         for req in requirements:
-            if "vllm-nccl-cu12" in req:
-                req = req.replace("vllm-nccl-cu12",
-                                  f"vllm-nccl-cu{cuda_major}")
-            elif ("vllm-flash-attn" in req
+            if ("vllm-flash-attn" in req
                     and not (cuda_major == "12" and cuda_minor == "1")):
                 # vllm-flash-attn is built only for CUDA 12.1.
                 # Skip for other versions.
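
With the vllm-nccl branch gone, the requirement post-processing in setup.py only has to skip the vllm-flash-attn pin on CUDA versions other than 12.1. A condensed, self-contained sketch of that filter (the function name and inline requirements list are illustrative, not part of setup.py):

def filter_cuda_requirements(requirements: list, cuda_version: str) -> list:
    """Skip vllm-flash-attn unless building against CUDA 12.1."""
    cuda_major, cuda_minor = cuda_version.split(".")[:2]
    kept = []
    for req in requirements:
        # vllm-flash-attn wheels exist only for CUDA 12.1; drop it otherwise
        if ("vllm-flash-attn" in req
                and not (cuda_major == "12" and cuda_minor == "1")):
            continue
        kept.append(req)
    return kept

# e.g. building on CUDA 11.8 drops the pin:
assert filter_cuda_requirements(
    ["torch == 2.3.0", "vllm-flash-attn == 2.5.8.post2"],
    "11.8") == ["torch == 2.3.0"]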

View File: tests/distributed/test_pynccl_library.py (deleted)

@@ -1,43 +0,0 @@
-import multiprocessing
-import tempfile
-
-
-def target_fn(env, filepath):
-    from vllm.utils import update_environment_variables
-    update_environment_variables(env)
-    from vllm.utils import nccl_integrity_check
-    nccl_integrity_check(filepath)
-
-
-def test_library_file():
-    # note: don't import vllm.distributed.device_communicators.pynccl
-    # before running this test, otherwise the library file will be loaded
-    # and it might interfere with the test
-    from vllm.utils import find_nccl_library
-    so_file = find_nccl_library()
-    with open(so_file, 'rb') as f:
-        content = f.read()
-    try:
-        # corrupt the library file, should raise an exception
-        with open(so_file, 'wb') as f:
-            f.write(content[:len(content) // 2])
-        p = multiprocessing.Process(target=target_fn, args=({}, so_file))
-        p.start()
-        p.join()
-        assert p.exitcode != 0
-
-        # move the library file to a tmp path
-        # test `VLLM_NCCL_SO_PATH`
-        fd, path = tempfile.mkstemp()
-        with open(path, 'wb') as f:
-            f.write(content)
-        p = multiprocessing.Process(target=target_fn,
-                                    args=({
-                                        "VLLM_NCCL_SO_PATH": path
-                                    }, path))
-        p.start()
-        p.join()
-        assert p.exitcode == 0
-    finally:
-        with open(so_file, 'wb') as f:
-            f.write(content)

View File: vllm/distributed/device_communicators/pynccl_wrapper.py

@@ -28,7 +28,7 @@ import torch
 from torch.distributed import ReduceOp

 from vllm.logger import init_logger
-from vllm.utils import find_nccl_library, nccl_integrity_check
+from vllm.utils import find_nccl_library

 logger = init_logger(__name__)
@@ -188,28 +188,22 @@ class NCCLLibrary:
         so_file = so_file or find_nccl_library()

         try:
-            # load the library in another process.
-            # if it core dumps, it will not crash the current process
-            nccl_integrity_check(so_file)
+            if so_file not in NCCLLibrary.path_to_dict_mapping:
+                lib = ctypes.CDLL(so_file)
+                NCCLLibrary.path_to_library_cache[so_file] = lib
+            self.lib = NCCLLibrary.path_to_library_cache[so_file]
         except Exception as e:
             logger.error(
                 "Failed to load NCCL library from %s ."
                 "It is expected if you are not running on NVIDIA/AMD GPUs."
                 "Otherwise, the nccl library might not exist, be corrupted "
                 "or it does not support the current platform %s."
-                "One solution is to download libnccl2 version 2.18 from "
-                "https://developer.download.nvidia.com/compute/cuda/repos/ "
-                "and extract the libnccl.so.2 file. If you already have the "
-                "library, please set the environment variable VLLM_NCCL_SO_PATH"
+                "If you already have the library, please set the "
+                "environment variable VLLM_NCCL_SO_PATH"
                 " to point to the correct nccl library path.", so_file,
                 platform.platform())
             raise e

-        if so_file not in NCCLLibrary.path_to_dict_mapping:
-            lib = ctypes.CDLL(so_file)
-            NCCLLibrary.path_to_library_cache[so_file] = lib
-        self.lib = NCCLLibrary.path_to_library_cache[so_file]
-
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs = {}
             for func in NCCLLibrary.exported_functions:
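
Moving the ctypes.CDLL call inside the try block means a missing or corrupt library now triggers the same logged VLLM_NCCL_SO_PATH hint instead of an unhandled OSError later. A standalone sketch of the resulting load-and-cache pattern (the class name is illustrative; NCCLLibrary's function-binding machinery is omitted):

import ctypes
from typing import Any, Dict


class LibraryLoader:
    # mirrors NCCLLibrary.path_to_library_cache: dlopen each path at
    # most once per process, then hand out the cached handle
    path_to_library_cache: Dict[str, Any] = {}

    def __init__(self, so_file: str):
        try:
            if so_file not in LibraryLoader.path_to_library_cache:
                lib = ctypes.CDLL(so_file)
                LibraryLoader.path_to_library_cache[so_file] = lib
            self.lib = LibraryLoader.path_to_library_cache[so_file]
        except OSError:
            # vLLM logs the VLLM_NCCL_SO_PATH hint here before re-raising
            raise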

View File: vllm/utils.py

@@ -2,7 +2,6 @@ import asyncio
 import datetime
 import enum
 import gc
-import glob
 import os
 import socket
 import subprocess
@@ -565,28 +564,6 @@ def init_cached_hf_modules():
     init_hf_modules()


-def nccl_integrity_check(filepath):
-    """
-    when the library is corrupted, we cannot catch
-    the exception in python. it will crash the process.
-    instead, we use the exit code of `ldd` to check
-    if the library is corrupted. if not, we will return
-    the version of the library.
-    """
-    exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
-    if exit_code != 0:
-        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
-    import ctypes
-
-    nccl = ctypes.CDLL(filepath)
-    version = ctypes.c_int()
-    nccl.ncclGetVersion.restype = ctypes.c_int
-    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-    result = nccl.ncclGetVersion(ctypes.byref(version))
-    assert result == 0
-    return version.value
-
-
 @lru_cache(maxsize=None)
 def find_library(lib_name: str) -> str:
     """
@@ -616,17 +593,13 @@ def find_library(lib_name: str) -> str:
 def find_nccl_library():
+    """
+    We either use the library file specified by the `VLLM_NCCL_SO_PATH`
+    environment variable, or we find the library file brought by PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
     so_file = envs.VLLM_NCCL_SO_PATH
-    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
-
-    # check if we have vllm-managed nccl
-    vllm_nccl_path = None
-    if torch.version.cuda is not None:
-        cuda_major = torch.version.cuda.split(".")[0]
-        path = os.path.expanduser(
-            f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
-        files = glob.glob(path)
-        vllm_nccl_path = files[0] if files else None

     # manually load the nccl library
     if so_file:
@@ -635,9 +608,9 @@ def find_nccl_library():
             so_file)
     else:
         if torch.version.cuda is not None:
-            so_file = vllm_nccl_path or find_library("libnccl.so.2")
+            so_file = "libnccl.so.2"
         elif torch.version.hip is not None:
-            so_file = find_library("librccl.so.1")
+            so_file = "librccl.so.1"
         else:
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
     logger.info("Found nccl from library %s", so_file)

View File: vllm/worker/worker_base.py

@@ -121,12 +121,14 @@ class WorkerWrapperBase:
     def init_worker(self, *args, **kwargs):
         """
-        Actual initialization of the worker class, and set up
-        function tracing if required.
+        Here we inject some common logic before initializing the worker.
+        Arguments are passed to the worker class constructor.
         """
         enable_trace_function_call_for_thread()
+        # see https://github.com/NVIDIA/nccl/issues/1234
+        os.environ['NCCL_CUMEM_ENABLE'] = '0'
         mod = importlib.import_module(self.worker_module_name)
         worker_class = getattr(mod, self.worker_class_name)
         self.worker = worker_class(*args, **kwargs)
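
Note the ordering in the hunk: the environment variable is set before the worker module is even imported, since NCCL reads NCCL_CUMEM_ENABLE when it first initializes a communicator and setting it later would not affect that communicator. A minimal illustration of the same ordering (the helper and its arguments are hypothetical, mirroring WorkerWrapperBase.init_worker):

import importlib
import os


def make_worker(worker_module_name: str, worker_class_name: str,
                *args, **kwargs):
    # must run before anything initializes NCCL
    # see https://github.com/NVIDIA/nccl/issues/1234
    os.environ["NCCL_CUMEM_ENABLE"] = "0"
    mod = importlib.import_module(worker_module_name)
    worker_class = getattr(mod, worker_class_name)
    return worker_class(*args, **kwargs)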