[Core][Optimization] remove vllm-nccl (#5091)
commit 5bd3c65072 (parent 616e600e0b)
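Summary of the change, as recoverable from the hunks below: the vllm-nccl-cu12 package (and the out-of-process NCCL integrity check that came with it) is removed. vLLM now loads whatever NCCL/RCCL library PyTorch already brings into the process, or one explicitly pointed to by VLLM_NCCL_SO_PATH, and sets NCCL_CUMEM_ENABLE=0 in the worker wrapper to work around https://github.com/NVIDIA/nccl/issues/1234.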
@@ -37,7 +37,6 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
-  - pytest -v -s distributed/test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
@@ -4,7 +4,6 @@
 # Dependencies for NVIDIA GPUs
 ray >= 2.9
 nvidia-ml-py # for pynvml package
-vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
 torch == 2.3.0
 xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
 vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0
setup.py
@@ -358,11 +358,8 @@ def get_requirements() -> List[str]:
         cuda_major, cuda_minor = torch.version.cuda.split(".")
         modified_requirements = []
         for req in requirements:
-            if "vllm-nccl-cu12" in req:
-                req = req.replace("vllm-nccl-cu12",
-                                  f"vllm-nccl-cu{cuda_major}")
-            elif ("vllm-flash-attn" in req
-                  and not (cuda_major == "12" and cuda_minor == "1")):
+            if ("vllm-flash-attn" in req
+                    and not (cuda_major == "12" and cuda_minor == "1")):
                 # vllm-flash-attn is built only for CUDA 12.1.
                 # Skip for other versions.
                 continue
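With the vllm-nccl-cu12 rewrite gone, the loop only has to drop vllm-flash-attn on non-12.1 CUDA builds. A minimal runnable sketch of the surviving filter, with a hypothetical inline requirements list (the real code parses it from a requirements file) and an extra None guard that setup.py handles elsewhere:

    import torch


    def filter_requirements(requirements):
        # Assumes a CUDA build of torch; a CPU/ROCm build has no
        # torch.version.cuda to parse.
        if torch.version.cuda is None:
            return requirements
        cuda_major, cuda_minor = torch.version.cuda.split(".")
        modified_requirements = []
        for req in requirements:
            if ("vllm-flash-attn" in req
                    and not (cuda_major == "12" and cuda_minor == "1")):
                # vllm-flash-attn ships wheels only for CUDA 12.1.
                continue
            modified_requirements.append(req)
        return modified_requirements


    print(filter_requirements(["ray >= 2.9",
                               "vllm-flash-attn == 2.5.8.post2"]))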
tests/distributed/test_pynccl_library.py (deleted)
@@ -1,43 +0,0 @@
-import multiprocessing
-import tempfile
-
-
-def target_fn(env, filepath):
-    from vllm.utils import update_environment_variables
-    update_environment_variables(env)
-    from vllm.utils import nccl_integrity_check
-    nccl_integrity_check(filepath)
-
-
-def test_library_file():
-    # note: don't import vllm.distributed.device_communicators.pynccl
-    # before running this test, otherwise the library file will be loaded
-    # and it might interfere with the test
-    from vllm.utils import find_nccl_library
-    so_file = find_nccl_library()
-    with open(so_file, 'rb') as f:
-        content = f.read()
-    try:
-        # corrupt the library file, should raise an exception
-        with open(so_file, 'wb') as f:
-            f.write(content[:len(content) // 2])
-        p = multiprocessing.Process(target=target_fn, args=({}, so_file))
-        p.start()
-        p.join()
-        assert p.exitcode != 0
-
-        # move the library file to a tmp path
-        # test VLLM_NCCL_SO_PATH
-        fd, path = tempfile.mkstemp()
-        with open(path, 'wb') as f:
-            f.write(content)
-        p = multiprocessing.Process(target=target_fn,
-                                    args=({
-                                        "VLLM_NCCL_SO_PATH": path
-                                    }, path))
-        p.start()
-        p.join()
-        assert p.exitcode == 0
-    finally:
-        with open(so_file, 'wb') as f:
-            f.write(content)
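The deleted test ran its checks in a subprocess for a reason: loading a corrupted shared library can crash the interpreter outright, which no Python try/except can catch, so the only safe probe is a child process whose exit code is inspected. A self-contained sketch of that isolation pattern; risky_load is a hypothetical stand-in for the real check:

    import ctypes
    import multiprocessing


    def risky_load(path):
        # Loading a corrupt .so can hard-crash the process;
        # doing it here confines the damage to this child.
        ctypes.CDLL(path)


    if __name__ == "__main__":
        p = multiprocessing.Process(target=risky_load,
                                    args=("libnccl.so.2", ))
        p.start()
        p.join()
        # A nonzero exit code means the load raised or crashed.
        print("child exit code:", p.exitcode)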
@@ -28,7 +28,7 @@ import torch
 from torch.distributed import ReduceOp

 from vllm.logger import init_logger
-from vllm.utils import find_nccl_library, nccl_integrity_check
+from vllm.utils import find_nccl_library

 logger = init_logger(__name__)
@@ -188,28 +188,22 @@ class NCCLLibrary:
         so_file = so_file or find_nccl_library()

         try:
-            # load the library in another process.
-            # if it core dumps, it will not crash the current process
-            nccl_integrity_check(so_file)
+            if so_file not in NCCLLibrary.path_to_dict_mapping:
+                lib = ctypes.CDLL(so_file)
+                NCCLLibrary.path_to_library_cache[so_file] = lib
+            self.lib = NCCLLibrary.path_to_library_cache[so_file]
         except Exception as e:
             logger.error(
                 "Failed to load NCCL library from %s ."
                 "It is expected if you are not running on NVIDIA/AMD GPUs."
                 "Otherwise, the nccl library might not exist, be corrupted "
                 "or it does not support the current platform %s."
-                "One solution is to download libnccl2 version 2.18 from "
-                "https://developer.download.nvidia.com/compute/cuda/repos/ "
-                "and extract the libnccl.so.2 file. If you already have the "
-                "library, please set the environment variable VLLM_NCCL_SO_PATH"
+                "If you already have the library, please set the "
+                "environment variable VLLM_NCCL_SO_PATH"
                 " to point to the correct nccl library path.", so_file,
                 platform.platform())
             raise e

-        if so_file not in NCCLLibrary.path_to_dict_mapping:
-            lib = ctypes.CDLL(so_file)
-            NCCLLibrary.path_to_library_cache[so_file] = lib
-        self.lib = NCCLLibrary.path_to_library_cache[so_file]
-
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs = {}
             for func in NCCLLibrary.exported_functions:
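The change folds the ctypes.CDLL call into the existing try/except (replacing the out-of-process integrity check) and keeps a per-path cache so each library is dlopen'ed at most once per process. A minimal sketch of that load-and-cache pattern, with illustrative class and attribute names:

    import ctypes


    class CachedLibrary:
        # Maps a resolved .so path to its ctypes handle so repeated
        # instantiations reuse the same dlopen'ed library.
        path_to_library_cache = {}

        def __init__(self, so_file: str):
            if so_file not in CachedLibrary.path_to_library_cache:
                CachedLibrary.path_to_library_cache[so_file] = \
                    ctypes.CDLL(so_file)
            self.lib = CachedLibrary.path_to_library_cache[so_file]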
@@ -2,7 +2,6 @@ import asyncio
 import datetime
 import enum
 import gc
-import glob
 import os
 import socket
 import subprocess
@@ -565,28 +564,6 @@ def init_cached_hf_modules():
     init_hf_modules()


-def nccl_integrity_check(filepath):
-    """
-    when the library is corrupted, we cannot catch
-    the exception in python. it will crash the process.
-    instead, we use the exit code of `ldd` to check
-    if the library is corrupted. if not, we will return
-    the version of the library.
-    """
-    exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
-    if exit_code != 0:
-        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
-    import ctypes
-
-    nccl = ctypes.CDLL(filepath)
-    version = ctypes.c_int()
-    nccl.ncclGetVersion.restype = ctypes.c_int
-    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-    result = nccl.ncclGetVersion(ctypes.byref(version))
-    assert result == 0
-    return version.value
-
-
 @lru_cache(maxsize=None)
 def find_library(lib_name: str) -> str:
     """
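The removed helper doubled as a version probe through NCCL's C API. For reference, querying the version via ctypes looks roughly like this; the decoding assumes NCCL's documented encoding for 2.9+, major * 10000 + minor * 100 + patch:

    import ctypes

    nccl = ctypes.CDLL("libnccl.so.2")
    nccl.ncclGetVersion.restype = ctypes.c_int
    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

    version = ctypes.c_int()
    # ncclGetVersion returns ncclSuccess (0) on success.
    assert nccl.ncclGetVersion(ctypes.byref(version)) == 0
    code = version.value
    print(f"NCCL {code // 10000}.{code % 10000 // 100}.{code % 100}")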
@@ -616,17 +593,13 @@ def find_library(lib_name: str) -> str:


 def find_nccl_library():
+    """
+    We either use the library file specified by the `VLLM_NCCL_SO_PATH`
+    environment variable, or we find the library file brought by PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
     so_file = envs.VLLM_NCCL_SO_PATH
-    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
-
-    # check if we have vllm-managed nccl
-    vllm_nccl_path = None
-    if torch.version.cuda is not None:
-        cuda_major = torch.version.cuda.split(".")[0]
-        path = os.path.expanduser(
-            f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
-        files = glob.glob(path)
-        vllm_nccl_path = files[0] if files else None
-
     # manually load the nccl library
     if so_file:
@@ -635,9 +608,9 @@ def find_nccl_library():
                     so_file)
     else:
         if torch.version.cuda is not None:
-            so_file = vllm_nccl_path or find_library("libnccl.so.2")
+            so_file = "libnccl.so.2"
         elif torch.version.hip is not None:
-            so_file = find_library("librccl.so.1")
+            so_file = "librccl.so.1"
         else:
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
     logger.info("Found nccl from library %s", so_file)
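Returning a bare soname works because, per the new docstring above, importing torch already maps PyTorch's bundled NCCL (or RCCL) into the process, so the dynamic loader can resolve the name without a full path. A sketch of the resolution order the function now relies on:

    import os

    import torch  # noqa: F401  (brings PyTorch's bundled NCCL along)
    import ctypes

    # An explicit override wins; otherwise the bare soname is resolved
    # by the loader against the copy torch already mapped in.
    so_file = os.environ.get("VLLM_NCCL_SO_PATH") or "libnccl.so.2"
    lib = ctypes.CDLL(so_file)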
@@ -121,12 +121,14 @@ class WorkerWrapperBase:

     def init_worker(self, *args, **kwargs):
         """
-        Actual initialization of the worker class, and set up
-        function tracing if required.
+        Here we inject some common logic before initializing the worker.
         Arguments are passed to the worker class constructor.
         """
         enable_trace_function_call_for_thread()

+        # see https://github.com/NVIDIA/nccl/issues/1234
+        os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
         mod = importlib.import_module(self.worker_module_name)
         worker_class = getattr(mod, self.worker_class_name)
         self.worker = worker_class(*args, **kwargs)
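NCCL reads its environment knobs when a communicator is initialized, so the override has to land before the worker module triggers any NCCL setup; placing it in the wrapper guarantees that ordering for every worker. A sketch of the pattern with placeholder module/class names:

    import importlib
    import os


    def init_worker(worker_module_name, worker_class_name, *args, **kwargs):
        # Must run before anything initializes NCCL;
        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'
        mod = importlib.import_module(worker_module_name)
        worker_class = getattr(mod, worker_class_name)
        return worker_class(*args, **kwargs)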