[Misc] Add a wrapper for torch.inference_mode (#6618)
Commit: 42de2cefcb
Parent: c9eef37f32
vllm/platforms/__init__.py

@@ -2,7 +2,9 @@ from typing import Optional
 
 import torch
 
-from .interface import Platform, PlatformEnum
+from vllm.utils import is_tpu
+
+from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
 current_platform: Optional[Platform]
 
@@ -12,7 +14,10 @@ if torch.version.cuda is not None:
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
+elif is_tpu():
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
 else:
-    current_platform = None
+    current_platform = UnspecifiedPlatform()
 
 __all__ = ['Platform', 'PlatformEnum', 'current_platform']
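
Note: with this change `current_platform` is always a `Platform` instance (an `UnspecifiedPlatform` when no backend is detected) rather than possibly `None`, so callers can query it without a `None` check. A minimal usage sketch; the branch bodies are purely illustrative and not part of this commit:

    from vllm.platforms import current_platform

    if current_platform.is_tpu():
        # pick an XLA-friendly code path
        ...
    elif current_platform.is_rocm():
        # pick a ROCm-specific code path
        ...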
vllm/platforms/interface.py

@@ -1,10 +1,14 @@
 import enum
 from typing import Tuple
 
+import torch
+
 
 class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
+    TPU = enum.auto()
+    UNSPECIFIED = enum.auto()
 
 
 class Platform:
@@ -16,6 +20,23 @@ class Platform:
     def is_rocm(self) -> bool:
         return self._enum == PlatformEnum.ROCM
 
+    def is_tpu(self) -> bool:
+        return self._enum == PlatformEnum.TPU
+
     @staticmethod
     def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         raise NotImplementedError
 
+    @staticmethod
+    def inference_mode():
+        """A device-specific wrapper of `torch.inference_mode`.
+
+        This wrapper is recommended because some hardware backends such as TPU
+        do not support `torch.inference_mode`. In such a case, they will fall
+        back to `torch.no_grad` by overriding this method.
+        """
+        return torch.inference_mode(mode=True)
+
+
+class UnspecifiedPlatform(Platform):
+    _enum = PlatformEnum.UNSPECIFIED
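
Note: `Platform.inference_mode()` returns the same kind of object that `torch.inference_mode()` does, so it works both as a decorator and as a context manager. A short sketch of both usages, assuming the wrapper above; `run_step`, `run_loop`, and `model` are hypothetical names, not part of this change:

    from vllm.platforms import current_platform

    # Decorator form, mirroring the call sites changed later in this diff.
    @current_platform.inference_mode()
    def run_step(model, batch):
        return model(batch)

    # Context-manager form; torch.inference_mode and torch.no_grad both support it.
    def run_loop(model, batches):
        with current_platform.inference_mode():
            return [model(b) for b in batches]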
vllm/platforms/tpu.py (new file, 17 lines added)

@@ -0,0 +1,17 @@
+from typing import Tuple
+
+import torch
+
+from .interface import Platform, PlatformEnum
+
+
+class TpuPlatform(Platform):
+    _enum = PlatformEnum.TPU
+
+    @staticmethod
+    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
+        raise RuntimeError("TPU does not have device capability.")
+
+    @staticmethod
+    def inference_mode():
+        return torch.no_grad()
vllm/worker/model_runner_base.py

@@ -5,6 +5,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type,
 
 import torch
 
+from vllm.platforms import current_platform
 from vllm.sequence import (IntermediateTensors, SamplerOutput,
                            SequenceGroupMetadata)
 
@@ -163,7 +164,7 @@ class ModelRunnerBase(ABC, Generic[T]):
         """
         raise NotImplementedError
 
-    @torch.inference_mode()
+    @current_platform.inference_mode()
     def execute_model(
         self,
         model_input: T,
vllm/worker/worker_base.py

@@ -9,6 +9,7 @@ import torch
 from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.platforms import current_platform
 from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
                            SamplerOutput)
 from vllm.utils import (enable_trace_function_call_for_thread,
@@ -53,7 +54,7 @@ class WorkerBase(ABC):
         """
         raise NotImplementedError
 
-    @torch.inference_mode()
+    @current_platform.inference_mode()
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.
 
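
Note: because `@current_platform.inference_mode()` is evaluated when the class body is executed, the platform-specific context is bound at import time. A quick way to check which mode is active at runtime; the tensor below is only illustrative:

    import torch
    from vllm.platforms import current_platform

    with current_platform.inference_mode():
        x = torch.ones(2, 2)
        # Under torch.inference_mode (CUDA/ROCm default) this prints True;
        # under the TPU fallback to torch.no_grad it prints False.
        print(x.is_inference())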