[4/N][torch.compile] clean up set_torch_compile_backend (#10401)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 47826cacf0
commit 51bb12d17b
@@ -2,15 +2,14 @@ import copy
 import dataclasses
 import operator
 from contextlib import ExitStack
-from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
-                    Union)
+from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch
 
 import torch
 import torch.fx as fx
 
 import vllm.envs as envs
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationConfig
 from vllm.logger import init_logger
 from vllm.utils import combine_fx_passes, weak_ref_tensors
@@ -684,14 +683,3 @@ class PiecewiseBackend:
 
             entry.cudagraph.replay()
         return entry.output
-
-
-def select_default_backend(level: int) -> Union[str, Callable]:
-    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
-        backend_str = "eager"
-        return backend_str
-    assert level == CompilationLevel.PIECEWISE
-
-    from vllm.plugins import get_current_vllm_config
-    compilation_config = get_current_vllm_config().compilation_config
-    return VllmBackend(compilation_config)
@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher:
             # default compilation settings
             # compiling the forward method
 
-            # choose the compile backend
-
-            # if the user has set the backend, use it
-            from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend()
-            if backend is None:
-                from vllm.compilation.backends import select_default_backend
-                backend = select_default_backend(compilation_level)
+            from vllm.plugins import get_current_vllm_config
+            backend = get_current_vllm_config(
+            ).compilation_config.init_backend()
 
             compiled_callable = torch.compile(
                 self.forward,
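
The wrapper now resolves its backend from the compilation config carried by the current vLLM config instead of a process-wide global. A minimal sketch of the pattern, assuming it runs where the current vLLM config has been set (e.g. inside a `set_current_vllm_config(...)` context); `my_forward` is a hypothetical callable standing in for the wrapped forward method:

    import torch

    from vllm.plugins import get_current_vllm_config

    def my_forward(x: torch.Tensor) -> torch.Tensor:
        # hypothetical stand-in for the wrapped forward method
        return x * 2

    # resolve the backend from the active config, then compile against it
    backend = get_current_vllm_config().compilation_config.init_backend()
    compiled_forward = torch.compile(my_forward, backend=backend)
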
@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
-                        identity, print_warning_once)
+                        identity, print_warning_once, resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
         - 1: dynamo as is.
         - 2: dynamo once.
         - 3: piecewise compilation.
+    - backend: the backend for compilation. It needs to be a string.
+        - "" (empty string): use the default backend.
+        - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
+        - "full.module.name": a qualified name which can be used to import the backend function.
+        We use string to avoid serialization issues when using compilation in a distributed setting.
+        When the compilation level is 1 or 2, the backend is used for the compilation directly (it sees the whole graph).
+        When the compilation level is 3, the backend is used for the piecewise compilation (it sees a part of the graph).
     - custom_ops: fine-grained control over which custom ops to enable/disable.
         Use 'all' to enable all, 'none' to disable all.
         Also specify a list of custom op names to enable (prefixed with a '+'),
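
The three accepted forms of `backend`, shown as a sketch of user-facing configuration; `my_pkg.my_backend` is a hypothetical qualified name, not a real module:

    from vllm.config import CompilationConfig

    # "" (default): let vLLM pick a backend appropriate for the level
    cfg_default = CompilationConfig(level=2, backend="")

    # a backend name registered with PyTorch Dynamo, e.g. "eager"
    cfg_eager = CompilationConfig(level=2, backend="eager")

    # a qualified name, imported at init time ("my_pkg.my_backend" is hypothetical)
    cfg_custom = CompilationConfig(level=2, backend="my_pkg.my_backend")
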
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
     certain small batchsizes, where inductor is good at optimizing.
     """ # noqa
     level: int = 0
+    backend: str = ""
     custom_ops: List[str] = Field(default_factory=list)
 
     use_inductor: bool = True
@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
                 func = __import__(module).__dict__[func_name]
                 self.inductor_compile_config[k] = func
 
+    def init_backend(self) -> Union[str, Callable]:
+        if self.level == CompilationLevel.NO_COMPILATION:
+            raise ValueError("No compilation level is set.")
+
+        from torch._dynamo.backends.registry import list_backends
+        torch_backends = list_backends(exclude_tags=tuple())
+        if self.level in [
+                CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
+        ]:
+            if self.backend == "":
+                return "eager"
+            if self.backend in torch_backends:
+                return self.backend
+            return resolve_obj_by_qualname(self.backend)
+
+        # TODO: pass user-specified backend to piecewise compilation
+        # merge with the config use_inductor
+        assert self.level == CompilationLevel.PIECEWISE
+        from vllm.compilation.backends import VllmBackend
+        return VllmBackend(self)
+
     def init_during_runtime(self):
         """To complete the initialization of config,
         we need to know the compile context, which is only available
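
The resolution order in `init_backend` can be exercised directly; a sketch, assuming a recent PyTorch where "eager" is a registered Dynamo backend and using `operator.add` merely as a stand-in for an importable backend callable:

    import operator

    from vllm.config import CompilationConfig

    # level 1/2 with an empty backend falls back to "eager"
    assert CompilationConfig(level=1).init_backend() == "eager"

    # a name in torch._dynamo's registry is returned unchanged
    assert CompilationConfig(level=2, backend="eager").init_backend() == "eager"

    # anything else is treated as "module.attr" and imported
    fn = CompilationConfig(level=2, backend="operator.add").init_backend()
    assert fn is operator.add

    # level 3 currently ignores the string (see the TODO) and builds a VllmBackend
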
@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING
 import torch
 
-from vllm.plugins import set_torch_compile_backend
-
 from .interface import Platform, PlatformEnum
 
 if TYPE_CHECKING:
@@ -12,8 +10,6 @@ if TYPE_CHECKING:
 else:
     VllmConfig = None
 
-set_torch_compile_backend("openxla")
-
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -38,3 +34,6 @@ class TpuPlatform(Platform):
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
         assert compilation_config.level < CompilationLevel.PIECEWISE,\
             "TPU does not support Inductor."
+
+        if compilation_config.backend == "":
+            compilation_config.backend = "openxla"
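
The platform hook only fills in `backend` when the user left it empty, so an explicit choice survives; a small sketch of that behavior (the helper name is hypothetical, mirroring the hook above):

    from vllm.config import CompilationConfig

    def apply_tpu_default(cfg: CompilationConfig) -> None:
        # mirrors the platform hook above
        if cfg.backend == "":
            cfg.backend = "openxla"

    cfg = CompilationConfig(level=2)                    # backend left empty
    apply_tpu_default(cfg)
    assert cfg.backend == "openxla"

    cfg = CompilationConfig(level=2, backend="eager")   # explicit choice kept
    apply_tpu_default(cfg)
    assert cfg.backend == "eager"
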
@@ -1,6 +1,6 @@
 import logging
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import vllm.envs as envs
@@ -50,18 +50,6 @@ def load_general_plugins():
             logger.exception("Failed to load plugin %s", plugin.name)
 
 
-_torch_compile_backend: Optional[Union[Callable, str]] = None
-
-
-def set_torch_compile_backend(backend: Union[Callable, str]):
-    global _torch_compile_backend
-    _torch_compile_backend = backend
-
-
-def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
-    return _torch_compile_backend
-
-
 _compilation_config: Optional[CompilationConfig] = None
 
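
With the module-level setter/getter removed, the equivalent intent is expressed through the config; a before/after sketch:

    # before (removed in this commit):
    #     from vllm.plugins import set_torch_compile_backend
    #     set_torch_compile_backend("openxla")

    # after: carry the backend on the compilation config instead
    from vllm.config import CompilationConfig

    config = CompilationConfig(level=2, backend="openxla")
    # config.init_backend() hands the name to torch.compile at level 1/2
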
@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
     my_lib.impl(op_name, op_func, "CUDA")
     if fake_impl is not None:
         my_lib._register_fake(op_name, fake_impl)
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+    """
+    module_name, obj_name = qualname.rsplit(".", 1)
+    module = importlib.import_module(module_name)
+    return getattr(module, obj_name)
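
A quick usage check for the new helper; `os.path.join` is just a convenient stand-in target:

    import os.path

    from vllm.utils import resolve_obj_by_qualname

    # "os.path.join" splits into module "os.path" and attribute "join"
    join = resolve_obj_by_qualname("os.path.join")
    assert join is os.path.join
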
@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
 
         if self.vllm_config.compilation_config.level ==\
             CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
-            from vllm.plugins import get_torch_compile_backend
-            backend = get_torch_compile_backend() or "eager"
+            backend = self.vllm_config.compilation_config.init_backend()
             self.model = torch.compile(
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,