import inspect
from typing import Dict, List, Optional, Union

import torch

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

logger = init_logger(__name__)


def support_torch_compile(
        cls: Optional[type] = None,
        dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.
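
    For example, to mark multiple dimensions of the same argument as dynamic
    (an illustrative sketch; `MyModel` and the chosen dimensions are made up):

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": [0, 1]})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor):
            # both the first and second dimensions of `x` are treated as
            # dynamic shapes by torch.compile
            ...
    ```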

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
      `Optional[torch.Tensor]`, the first dimension will be
      marked as dynamic.
    - if the argument is annotated as `IntermediateTensors`, the first
      dimension of all the tensors in the intermediate tensors
      will be marked as dynamic.

    At runtime, when the dimensions are actually marked, the behavior depends
    on the value of the argument:

    - if it is a `torch.Tensor`, the given dimension(s) of that tensor
      will be marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is an `IntermediateTensors` instance, the given dimension(s) of
      every tensor it contains will be marked as dynamic (see the sketch
      below).
    - otherwise, a `ValueError` is raised.
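
    As an illustrative sketch (the class and argument names are hypothetical),
    a pipeline-parallel stage that receives intermediate tensors could look
    like this:

    ```python
    @support_torch_compile
    class MyPipelineStage(nn.Module):
        def forward(self, x: torch.Tensor,
                    intermediate_tensors: Optional[IntermediateTensors]):
            # when `intermediate_tensors` is not None, the first dimension of
            # every tensor it holds is marked as dynamic, along with the
            # first dimension of `x`
            ...
    ```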

    NOTE: if an argument is `None`, it should always be passed as `None`
    during the lifetime of the model; otherwise, it cannot be captured as a
    single computation graph.
    """

    def cls_decorator_helper(cls: type):
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # without adding an extra level of indentation there
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                        torch.Tensor, Optional[torch.Tensor],
                        IntermediateTensors, Optional[IntermediateTensors]
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if len(inferred_dynamic_arg_dims) == 0:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, inferred_dynamic_arg_dims)

    if cls is not None:
        # `support_torch_compile` was used as a decorator without arguments
        assert isinstance(cls, type)
        return cls_decorator_helper(cls)

    return cls_decorator_helper


def _support_torch_compile(cls: type,
                           dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    """
    The implementation of `support_torch_compile`; see that decorator's
    docstring for usage details.
    """

    # for CompilationLevel.DYNAMO_AS_IS, the upper-level model runner
    # will handle the compilation, so we don't need to do anything here.
    if envs.VLLM_TORCH_COMPILE_LEVEL in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
    ] or not supports_dynamo():
        return cls

    # take care of the method resolution order:
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__  # type: ignore

    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__  # type: ignore

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are already inside Dynamo
        # compilation; e.g. TPU has the compilation logic in the model
        # runner, so we don't need to compile the model here.
        if torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
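            # bind positional and keyword arguments to parameter names so
            # that each entry in `dynamic_arg_dims` can be looked up by name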
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    if isinstance(arg, torch.Tensor):
                        torch._dynamo.mark_dynamic(arg, dims)
                    elif isinstance(arg, IntermediateTensors):
                        for tensor in arg.tensors.values():
                            torch._dynamo.mark_dynamic(tensor, dims)
                    else:
                        raise ValueError(
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}.")

        # if we don't use the custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            return self.compiled_callable(*args, **kwargs)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__  # type: ignore
    return cls