[Misc] Add a wrapper for torch.inference_mode (#6618)

This commit is contained in:
Woosuk Kwon 2024-07-21 18:43:11 -07:00 committed by GitHub
parent c9eef37f32
commit 42de2cefcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 4 deletions

View File

@ -2,7 +2,9 @@ from typing import Optional
import torch
from .interface import Platform, PlatformEnum
from vllm.utils import is_tpu
from .interface import Platform, PlatformEnum, UnspecifiedPlatform
current_platform: Optional[Platform]
@ -12,7 +14,10 @@ if torch.version.cuda is not None:
elif torch.version.hip is not None:
from .rocm import RocmPlatform
current_platform = RocmPlatform()
elif is_tpu():
from .tpu import TpuPlatform
current_platform = TpuPlatform()
else:
current_platform = None
current_platform = UnspecifiedPlatform()
__all__ = ['Platform', 'PlatformEnum', 'current_platform']

View File

@ -1,10 +1,14 @@
import enum
from typing import Tuple
import torch
class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
TPU = enum.auto()
UNSPECIFIED = enum.auto()
class Platform:
@ -16,6 +20,23 @@ class Platform:
def is_rocm(self) -> bool:
return self._enum == PlatformEnum.ROCM
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU
@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise NotImplementedError
@staticmethod
def inference_mode():
"""A device-specific wrapper of `torch.inference_mode`.
This wrapper is recommended because some hardware backends such as TPU
do not support `torch.inference_mode`. In such a case, they will fall
back to `torch.no_grad` by overriding this method.
"""
return torch.inference_mode(mode=True)
class UnspecifiedPlatform(Platform):
_enum = PlatformEnum.UNSPECIFIED

17
vllm/platforms/tpu.py Normal file
View File

@ -0,0 +1,17 @@
from typing import Tuple
import torch
from .interface import Platform, PlatformEnum
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise RuntimeError("TPU does not have device capability.")
@staticmethod
def inference_mode():
return torch.no_grad()

View File

@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
import torch
from vllm.platforms import current_platform
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -163,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise NotImplementedError
@torch.inference_mode()
@current_platform.inference_mode()
def execute_model(
self,
model_input: T,

View File

@ -9,6 +9,7 @@ import torch
from vllm.distributed import broadcast_tensor_dict, get_pp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.utils import (enable_trace_function_call_for_thread,
@ -53,7 +54,7 @@ class WorkerBase(ABC):
"""
raise NotImplementedError
@torch.inference_mode()
@current_platform.inference_mode()
def start_worker_execution_loop(self) -> None:
"""Execute model loop in parallel worker.