[Core][Optimization] remove vllm-nccl (#5091)
This commit is contained in:
parent
616e600e0b
commit
5bd3c65072
@ -37,7 +37,6 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pynccl_library.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
|
||||
@ -4,7 +4,6 @@
|
||||
# Dependencies for NVIDIA GPUs
|
||||
ray >= 2.9
|
||||
nvidia-ml-py # for pynvml package
|
||||
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
|
||||
torch == 2.3.0
|
||||
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
|
||||
vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0
|
||||
|
||||
5
setup.py
5
setup.py
@ -358,10 +358,7 @@ def get_requirements() -> List[str]:
|
||||
cuda_major, cuda_minor = torch.version.cuda.split(".")
|
||||
modified_requirements = []
|
||||
for req in requirements:
|
||||
if "vllm-nccl-cu12" in req:
|
||||
req = req.replace("vllm-nccl-cu12",
|
||||
f"vllm-nccl-cu{cuda_major}")
|
||||
elif ("vllm-flash-attn" in req
|
||||
if ("vllm-flash-attn" in req
|
||||
and not (cuda_major == "12" and cuda_minor == "1")):
|
||||
# vllm-flash-attn is built only for CUDA 12.1.
|
||||
# Skip for other versions.
|
||||
|
||||
@ -1,43 +0,0 @@
|
||||
import multiprocessing
|
||||
import tempfile
|
||||
|
||||
|
||||
def target_fn(env, filepath):
|
||||
from vllm.utils import update_environment_variables
|
||||
update_environment_variables(env)
|
||||
from vllm.utils import nccl_integrity_check
|
||||
nccl_integrity_check(filepath)
|
||||
|
||||
|
||||
def test_library_file():
|
||||
# note: don't import vllm.distributed.device_communicators.pynccl
|
||||
# before running this test, otherwise the library file will be loaded
|
||||
# and it might interfere with the test
|
||||
from vllm.utils import find_nccl_library
|
||||
so_file = find_nccl_library()
|
||||
with open(so_file, 'rb') as f:
|
||||
content = f.read()
|
||||
try:
|
||||
# corrupt the library file, should raise an exception
|
||||
with open(so_file, 'wb') as f:
|
||||
f.write(content[:len(content) // 2])
|
||||
p = multiprocessing.Process(target=target_fn, args=({}, so_file))
|
||||
p.start()
|
||||
p.join()
|
||||
assert p.exitcode != 0
|
||||
|
||||
# move the library file to a tmp path
|
||||
# test VLLM_NCCL_SO_PATH
|
||||
fd, path = tempfile.mkstemp()
|
||||
with open(path, 'wb') as f:
|
||||
f.write(content)
|
||||
p = multiprocessing.Process(target=target_fn,
|
||||
args=({
|
||||
"VLLM_NCCL_SO_PATH": path
|
||||
}, path))
|
||||
p.start()
|
||||
p.join()
|
||||
assert p.exitcode == 0
|
||||
finally:
|
||||
with open(so_file, 'wb') as f:
|
||||
f.write(content)
|
||||
@ -28,7 +28,7 @@ import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_nccl_library, nccl_integrity_check
|
||||
from vllm.utils import find_nccl_library
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -188,28 +188,22 @@ class NCCLLibrary:
|
||||
so_file = so_file or find_nccl_library()
|
||||
|
||||
try:
|
||||
# load the library in another process.
|
||||
# if it core dumps, it will not crash the current process
|
||||
nccl_integrity_check(so_file)
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to load NCCL library from %s ."
|
||||
"It is expected if you are not running on NVIDIA/AMD GPUs."
|
||||
"Otherwise, the nccl library might not exist, be corrupted "
|
||||
"or it does not support the current platform %s."
|
||||
"One solution is to download libnccl2 version 2.18 from "
|
||||
"https://developer.download.nvidia.com/compute/cuda/repos/ "
|
||||
"and extract the libnccl.so.2 file. If you already have the "
|
||||
"library, please set the environment variable VLLM_NCCL_SO_PATH"
|
||||
"If you already have the library, please set the "
|
||||
"environment variable VLLM_NCCL_SO_PATH"
|
||||
" to point to the correct nccl library path.", so_file,
|
||||
platform.platform())
|
||||
raise e
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
lib = ctypes.CDLL(so_file)
|
||||
NCCLLibrary.path_to_library_cache[so_file] = lib
|
||||
self.lib = NCCLLibrary.path_to_library_cache[so_file]
|
||||
|
||||
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
||||
_funcs = {}
|
||||
for func in NCCLLibrary.exported_functions:
|
||||
|
||||
@ -2,7 +2,6 @@ import asyncio
|
||||
import datetime
|
||||
import enum
|
||||
import gc
|
||||
import glob
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
@ -565,28 +564,6 @@ def init_cached_hf_modules():
|
||||
init_hf_modules()
|
||||
|
||||
|
||||
def nccl_integrity_check(filepath):
|
||||
"""
|
||||
when the library is corrupted, we cannot catch
|
||||
the exception in python. it will crash the process.
|
||||
instead, we use the exit code of `ldd` to check
|
||||
if the library is corrupted. if not, we will return
|
||||
the version of the library.
|
||||
"""
|
||||
exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
|
||||
if exit_code != 0:
|
||||
raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
|
||||
import ctypes
|
||||
|
||||
nccl = ctypes.CDLL(filepath)
|
||||
version = ctypes.c_int()
|
||||
nccl.ncclGetVersion.restype = ctypes.c_int
|
||||
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
|
||||
result = nccl.ncclGetVersion(ctypes.byref(version))
|
||||
assert result == 0
|
||||
return version.value
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def find_library(lib_name: str) -> str:
|
||||
"""
|
||||
@ -616,17 +593,13 @@ def find_library(lib_name: str) -> str:
|
||||
|
||||
|
||||
def find_nccl_library():
|
||||
"""
|
||||
We either use the library file specified by the `VLLM_NCCL_SO_PATH`
|
||||
environment variable, or we find the library file brought by PyTorch.
|
||||
After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
|
||||
found by `ctypes` automatically.
|
||||
"""
|
||||
so_file = envs.VLLM_NCCL_SO_PATH
|
||||
VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
|
||||
|
||||
# check if we have vllm-managed nccl
|
||||
vllm_nccl_path = None
|
||||
if torch.version.cuda is not None:
|
||||
cuda_major = torch.version.cuda.split(".")[0]
|
||||
path = os.path.expanduser(
|
||||
f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
|
||||
files = glob.glob(path)
|
||||
vllm_nccl_path = files[0] if files else None
|
||||
|
||||
# manually load the nccl library
|
||||
if so_file:
|
||||
@ -635,9 +608,9 @@ def find_nccl_library():
|
||||
so_file)
|
||||
else:
|
||||
if torch.version.cuda is not None:
|
||||
so_file = vllm_nccl_path or find_library("libnccl.so.2")
|
||||
so_file = "libnccl.so.2"
|
||||
elif torch.version.hip is not None:
|
||||
so_file = find_library("librccl.so.1")
|
||||
so_file = "librccl.so.1"
|
||||
else:
|
||||
raise ValueError("NCCL only supports CUDA and ROCm backends.")
|
||||
logger.info("Found nccl from library %s", so_file)
|
||||
|
||||
@ -121,12 +121,14 @@ class WorkerWrapperBase:
|
||||
|
||||
def init_worker(self, *args, **kwargs):
|
||||
"""
|
||||
Actual initialization of the worker class, and set up
|
||||
function tracing if required.
|
||||
Here we inject some common logic before initializing the worker.
|
||||
Arguments are passed to the worker class constructor.
|
||||
"""
|
||||
enable_trace_function_call_for_thread()
|
||||
|
||||
# see https://github.com/NVIDIA/nccl/issues/1234
|
||||
os.environ['NCCL_CUMEM_ENABLE'] = '0'
|
||||
|
||||
mod = importlib.import_module(self.worker_module_name)
|
||||
worker_class = getattr(mod, self.worker_class_name)
|
||||
self.worker = worker_class(*args, **kwargs)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user