[Misc] Add a wrapper for torch.inference_mode (#6618)
commit 42de2cefcb (parent c9eef37f32)
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -2,7 +2,9 @@ from typing import Optional
 
 import torch
 
-from .interface import Platform, PlatformEnum
+from vllm.utils import is_tpu
+
+from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
 current_platform: Optional[Platform]
 
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
+elif is_tpu():
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
 else:
-    current_platform = None
+    current_platform = UnspecifiedPlatform()
 
 __all__ = ['Platform', 'PlatformEnum', 'current_platform']
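One practical effect of replacing the `None` fallback with `UnspecifiedPlatform()` is that importers can call platform predicates unconditionally. A minimal usage sketch, not part of the diff (assumes vLLM is importable):

```python
# Hypothetical usage sketch: current_platform is now always a Platform
# instance, so no None check is needed before calling its predicates.
from vllm.platforms import current_platform

if current_platform.is_tpu():
    print("TPU backend: inference_mode() falls back to torch.no_grad")
elif current_platform.is_rocm():
    print("ROCm backend")
else:
    print("CUDA or unspecified backend")
```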
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -1,10 +1,14 @@
 import enum
 from typing import Tuple
 
+import torch
+
 
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
+    TPU = enum.auto()
+    UNSPECIFIED = enum.auto()
 
 
 class Platform:
@@ -16,6 +20,23 @@ class Platform:
     def is_rocm(self) -> bool:
         return self._enum == PlatformEnum.ROCM
 
+    def is_tpu(self) -> bool:
+        return self._enum == PlatformEnum.TPU
+
     @staticmethod
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         raise NotImplementedError
+
+    @staticmethod
+    def inference_mode():
+        """A device-specific wrapper of `torch.inference_mode`.
+
+        This wrapper is recommended because some hardware backends such as TPU
+        do not support `torch.inference_mode`. In such a case, they will fall
+        back to `torch.no_grad` by overriding this method.
+        """
+        return torch.inference_mode(mode=True)
+
+
+class UnspecifiedPlatform(Platform):
+    _enum = PlatformEnum.UNSPECIFIED
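The override pattern above is simple: a backend subclass swaps in a different context manager by overriding the static method. A self-contained sketch of the same mechanics (the fallback platform class here is hypothetical, not from the diff):

```python
import torch


class Platform:

    @staticmethod
    def inference_mode():
        # Default: full inference mode (disables autograd tracking and more).
        return torch.inference_mode(mode=True)


class NoGradPlatform(Platform):

    @staticmethod
    def inference_mode():
        # Backends without torch.inference_mode support swap in no_grad.
        return torch.no_grad()


@NoGradPlatform.inference_mode()
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


out = double(torch.ones(2, requires_grad=True))
print(out.requires_grad)  # False: autograd is disabled inside the wrapper
```

Both `torch.inference_mode` and `torch.no_grad` instances work as decorators as well as `with`-contexts, which is what lets callers stay agnostic to which one the platform returns.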
--- /dev/null
+++ b/vllm/platforms/tpu.py
@@ -0,0 +1,17 @@
+from typing import Tuple
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+
+class TpuPlatform(Platform):
+    _enum = PlatformEnum.TPU
+
+    @staticmethod
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        raise RuntimeError("TPU does not have device capability.")
+
+    @staticmethod
+    def inference_mode():
+        return torch.no_grad()
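For reference, `torch.no_grad` is a close but not identical substitute: both contexts disable gradient tracking, which is all that serving needs, but only `torch.inference_mode` marks the tensors it produces as inference tensors. A quick CPU illustration (not part of the commit):

```python
import torch

with torch.no_grad():
    a = torch.ones(2)

with torch.inference_mode():
    b = torch.ones(2)

# Neither context builds an autograd graph, but only inference_mode
# produces inference tensors, which can never re-enter autograd later.
print(a.is_inference(), b.is_inference())  # False True
```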
--- a/vllm/worker/model_runner_base.py
+++ b/vllm/worker/model_runner_base.py
@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
 
 import torch
 
+from vllm.platforms import current_platform
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 
@@ -163,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
         """
         raise NotImplementedError
 
-    @torch.inference_mode()
+    @current_platform.inference_mode()
     def execute_model(
         self,
         model_input: T,
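Note the shape of this change: the decorator expression runs once, when the class body executes at import time, so the platform-specific context manager is chosen up front; PyTorch's gradient-mode context managers are designed to be reused as decorators, wrapping each call independently. The same swap appears on `WorkerBase.start_worker_execution_loop` below. A small sketch of the pattern (the runner class here is hypothetical):

```python
import torch

from vllm.platforms import current_platform


class DummyRunner:

    # Evaluated once at class-definition time; the resulting context
    # manager then wraps every call to execute_model.
    @current_platform.inference_mode()
    def execute_model(self, x: torch.Tensor) -> torch.Tensor:
        # No autograd graph is built inside this method.
        return x + 1
```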
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -9,6 +9,7 @@ import torch
 from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SamplerOutput)
 from vllm.utils import (enable_trace_function_call_for_thread,
@@ -53,7 +54,7 @@ class WorkerBase(ABC):
         """
         raise NotImplementedError
 
-    @torch.inference_mode()
+    @current_platform.inference_mode()
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.
 