vllm/vllm/worker/worker_base.py


import importlib
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Set, Tuple

from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (enable_trace_function_call_for_thread,
                        update_environment_variables)

logger = init_logger(__name__)


class WorkerBase(ABC):
    """Worker interface that allows vLLM to cleanly separate implementations
    for different hardware.
    """

    @abstractmethod
    def init_device(self) -> None:
        """Initialize device state, such as loading the model or other
        on-device memory allocations.
        """
        raise NotImplementedError

    @abstractmethod
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available blocks for the GPU KV cache and
        swappable CPU KV cache.

        The implementation may run profiling or other heuristics to determine
        the size of caches.

        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
        are blocks that are "active" on the device and can be appended to.
        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
        appended to.
        """
        raise NotImplementedError
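
    # How the result is typically consumed (a sketch of the call order, not
    # vLLM's exact executor code): the cache is sized once, before decoding
    # starts.
    #
    #     num_blocks = worker.determine_num_available_blocks()
    #     worker.initialize_cache(num_gpu_blocks=num_blocks[0],
    #                             num_cpu_blocks=num_blocks[1])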

    @abstractmethod
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache with the given size in blocks.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        raise NotImplementedError
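
    # Note (based on vLLM's distributed executors): typically only the driver
    # worker receives a non-None execute_model_req; other ranks call this
    # with None and get their inputs broadcast from the driver.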

    @abstractmethod
    def get_cache_block_size_bytes(self) -> int:
        """Return the size of a single cache block, in bytes. Used in
        speculative decoding.
        """
        raise NotImplementedError
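
    # Back-of-the-envelope sketch of a typical return value (the formula is
    # an assumption mirroring a cache that stores one key and one value
    # tensor per layer; real implementations derive it from the model
    # config):
    #
    #     # 16-token blocks, 32 layers, 32 KV heads, head_size 128, fp16:
    #     2 * 16 * 32 * 32 * 128 * 2  # = 8_388_608 bytes (8 MiB) per block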

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    @abstractmethod
    def list_loras(self) -> Set[int]:
        raise NotImplementedError


class LoraNotSupportedWorkerBase(WorkerBase):
    """Partial implementation of WorkerBase that raises exceptions when LoRA
    methods are invoked.
    """

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def remove_lora(self, lora_id: int) -> bool:
        raise ValueError(f"{type(self)} does not support LoRA")

    def pin_lora(self, lora_id: int) -> bool:
        # Raise (not return) the error so callers fail loudly, consistent
        # with the other LoRA stubs in this class.
        raise ValueError(f"{type(self)} does not support LoRA")

    def list_loras(self) -> Set[int]:
        raise ValueError(f"{type(self)} does not support LoRA")
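

# Sketch of a minimal concrete worker (hypothetical class and values, shown
# for illustration only). Subclassing LoraNotSupportedWorkerBase leaves just
# the non-LoRA abstract methods of WorkerBase to implement:
#
#     class NullWorker(LoraNotSupportedWorkerBase):
#         def init_device(self) -> None:
#             pass
#         def determine_num_available_blocks(self) -> Tuple[int, int]:
#             return 1024, 256
#         def initialize_cache(self, num_gpu_blocks: int,
#                              num_cpu_blocks: int) -> None:
#             pass
#         def execute_model(self, execute_model_req=None):
#             return []
#         def get_cache_block_size_bytes(self) -> int:
#             return 1 << 23  # 8 MiB, as in the arithmetic sketch above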


class WorkerWrapperBase:
    """
    The whole point of this class is to lazily initialize the worker.
    We first instantiate the WorkerWrapper, which remembers the worker module
    and class name. The environment is then prepared via
    `update_environment_variables`, and the real initialization happens
    later, in `init_worker`.
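
    Usage sketch (module and class names are placeholders):

        wrapper = WorkerWrapperBase("my_pkg.my_worker", "MyWorker")
        wrapper.update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})
        wrapper.init_worker()  # imports the module and builds the worker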
"""

    def __init__(self,
                 worker_module_name: str,
                 worker_class_name: str,
                 trust_remote_code: bool = False) -> None:
        self.worker_module_name = worker_module_name
        self.worker_class_name = worker_class_name
        self.worker = None
        if trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

    @staticmethod
    def update_environment_variables(envs: Dict[str, str]) -> None:
        key = 'CUDA_VISIBLE_DEVICES'
        if key in envs and key in os.environ:
            # Overwriting CUDA_VISIBLE_DEVICES is the desired behavior here;
            # deleting the old value suppresses the warning raised by
            # `update_environment_variables` when a key already exists.
            del os.environ[key]
        update_environment_variables(envs)
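
    # For example (values are illustrative), a launcher can push its device
    # mapping into a fresh worker process before the worker itself exists:
    #
    #     WorkerWrapperBase.update_environment_variables(
    #         {'CUDA_VISIBLE_DEVICES': '0,1,2,3'})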

    def init_worker(self, *args, **kwargs):
        """
        Here we inject some common logic before initializing the worker.
        Arguments are passed to the worker class constructor.
        """
        enable_trace_function_call_for_thread()

        # see https://github.com/NVIDIA/nccl/issues/1234
        os.environ['NCCL_CUMEM_ENABLE'] = '0'

        mod = importlib.import_module(self.worker_module_name)
        worker_class = getattr(mod, self.worker_class_name)
        self.worker = worker_class(*args, **kwargs)

    def execute_method(self, method, *args, **kwargs):
        try:
            # Before init_worker runs, fall back to the wrapper itself so
            # that setup methods can still be invoked remotely.
            target = self if self.worker is None else self.worker
            executor = getattr(target, method)
            return executor(*args, **kwargs)
        except Exception as e:
            # If the driver worker also executes methods, an exception raised
            # in a remote worker can deadlock RPC frameworks such as Ray.
            # See https://github.com/vllm-project/vllm/issues/3455.
            # Log the error to inform the user, then re-raise it.
            msg = (f"Error executing method {method}. "
                   "This might cause deadlock in distributed execution.")
            logger.exception(msg)
            raise e
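

# Dispatch sketch (hypothetical names): before init_worker runs,
# execute_method resolves attributes on the wrapper itself, which is what
# lets a launcher drive the whole lifecycle through one RPC entry point;
# afterwards, calls are forwarded to the real worker.
#
#     wrapper = WorkerWrapperBase("my_pkg.my_worker", "MyWorker")
#     wrapper.execute_method("update_environment_variables",
#                            {"CUDA_VISIBLE_DEVICES": "0"})  # wrapper method
#     wrapper.execute_method("init_worker")                  # wrapper method
#     wrapper.execute_method("init_device")                  # forwarded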