[4/N][torch.compile] clean up set_torch_compile_backend (#10401)

Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao 2024-11-17 23:57:20 -08:00 committed by GitHub
parent 47826cacf0
commit 51bb12d17b
7 changed files with 49 additions and 42 deletions

View File

@@ -2,15 +2,14 @@ import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
                    Union)
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
from unittest.mock import patch
import torch
import torch.fx as fx
import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from vllm.utils import combine_fx_passes, weak_ref_tensors
@@ -684,14 +683,3 @@ class PiecewiseBackend:
        entry.cudagraph.replay()
        return entry.output


def select_default_backend(level: int) -> Union[str, Callable]:
    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
        backend_str = "eager"
        return backend_str

    assert level == CompilationLevel.PIECEWISE
    from vllm.plugins import get_current_vllm_config
    compilation_config = get_current_vllm_config().compilation_config
    return VllmBackend(compilation_config)

View File

@@ -32,14 +32,9 @@ class TorchCompileWrapperWithCustomDispatcher:
            # default compilation settings
            # compiling the forward method
            # choose the compile backend
            # if the user has set the backend, use it
            from vllm.plugins import get_torch_compile_backend
            backend = get_torch_compile_backend()
            if backend is None:
                from vllm.compilation.backends import select_default_backend
                backend = select_default_backend(compilation_level)
            from vllm.plugins import get_current_vllm_config
            backend = get_current_vllm_config(
            ).compilation_config.init_backend()
            compiled_callable = torch.compile(
                self.forward,
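
Note on the new call pattern: torch.compile accepts either the name of a registered backend (a string) or a backend callable, which is why init_backend() is typed as Union[str, Callable]. A minimal sketch outside vLLM, where TinyModule and my_backend are hypothetical, illustration-only names:

import torch


def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # trivial custom backend: run the captured FX graph eagerly
    return gm.forward


class TinyModule(torch.nn.Module):
    def forward(self, x):
        return x * 2


mod = TinyModule()
# string form, e.g. the "eager" default returned for compilation levels 1 and 2
fn_eager = torch.compile(mod.forward, backend="eager")
# callable form, analogous to the VllmBackend instance returned for level 3
fn_custom = torch.compile(mod.forward, backend=my_backend)
print(fn_eager(torch.ones(2)), fn_custom(torch.ones(2)))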

View File

@@ -22,7 +22,7 @@ from vllm.transformers_utils.config import (
    get_hf_text_config, get_pooling_config,
    get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope)
from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                        identity, print_warning_once)
                        identity, print_warning_once, resolve_obj_by_qualname)
if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup
@@ -2072,6 +2072,13 @@ class CompilationConfig(BaseModel):
        - 1: dynamo as is.
        - 2: dynamo once.
        - 3: piecewise compilation.
    - backend: the backend for compilation. It needs to be a string:
        - "" (empty string): use the default backend.
        - "eager"/"openxla"/...: use the specified backend registered in PyTorch.
        - "full.module.name": a qualified name which can be used to import the backend function.
      We use a string to avoid serialization issues when using compilation in a distributed setting.
      When the compilation level is 1 or 2, the backend is used for compilation directly (it sees the whole graph).
      When the compilation level is 3, the backend is used for piecewise compilation (it sees a part of the graph).
    - custom_ops: fine-grained control over which custom ops to enable/disable.
      Use 'all' to enable all, 'none' to disable all.
      Also specify a list of custom op names to enable (prefixed with a '+'),
@@ -2139,6 +2146,7 @@ class CompilationConfig(BaseModel):
    certain small batch sizes, where inductor is good at optimizing.
    """  # noqa
    level: int = 0
    backend: str = ""
    custom_ops: List[str] = Field(default_factory=list)
    use_inductor: bool = True
@@ -2182,6 +2190,27 @@ class CompilationConfig(BaseModel):
                func = __import__(module).__dict__[func_name]
                self.inductor_compile_config[k] = func

    def init_backend(self) -> Union[str, Callable]:
        if self.level == CompilationLevel.NO_COMPILATION:
            raise ValueError("No compilation level is set.")

        from torch._dynamo.backends.registry import list_backends
        torch_backends = list_backends(exclude_tags=tuple())

        if self.level in [
                CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE
        ]:
            if self.backend == "":
                return "eager"
            if self.backend in torch_backends:
                return self.backend
            return resolve_obj_by_qualname(self.backend)

        # TODO: pass user-specified backend to piecewise compilation
        # merge with the config use_inductor
        assert self.level == CompilationLevel.PIECEWISE
        from vllm.compilation.backends import VllmBackend
        return VllmBackend(self)

    def init_during_runtime(self):
        """To complete the initialization of config,
        we need to know the compile context, which is only available
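
To make the new knob concrete, here is a hedged sketch of how the backend field and init_backend() interact, using only behavior visible in the hunks above; the qualified name my_pkg.compile.my_backend is a made-up example, not part of this commit:

from vllm.config import CompilationConfig, CompilationLevel

# "" picks the default: "eager" at levels 1/2, a VllmBackend instance at level 3
cfg = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE, backend="")
assert cfg.init_backend() == "eager"

# a backend name registered with torch._dynamo (e.g. "openxla" once torch_xla
# is installed) is returned as-is and later handed to torch.compile
cfg = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE, backend="openxla")

# anything else is treated as a fully qualified "module.attr" path and imported
# through resolve_obj_by_qualname when init_backend() is called
cfg = CompilationConfig(level=CompilationLevel.DYNAMO_AS_IS,
                        backend="my_pkg.compile.my_backend")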

View File

@@ -3,8 +3,6 @@ from typing import TYPE_CHECKING

import torch

from vllm.plugins import set_torch_compile_backend

from .interface import Platform, PlatformEnum

if TYPE_CHECKING:
@@ -12,8 +10,6 @@ if TYPE_CHECKING:
else:
    VllmConfig = None

set_torch_compile_backend("openxla")


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU
@@ -38,3 +34,6 @@ class TpuPlatform(Platform):
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
        assert compilation_config.level < CompilationLevel.PIECEWISE,\
            "TPU does not support Inductor."

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"
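
In other words, the platform only fills in a default; a backend the user already configured takes precedence. A minimal sketch of that precedence, with an illustrative user value:

from vllm.config import CompilationConfig, CompilationLevel

cfg = CompilationConfig(level=CompilationLevel.DYNAMO_ONCE,
                        backend="my.custom.backend")  # hypothetical user setting
# mirrors the hunk above: the TPU default is applied only when nothing was set
if cfg.backend == "":
    cfg.backend = "openxla"
assert cfg.backend == "my.custom.backend"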

View File

@@ -1,6 +1,6 @@
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Optional

import vllm.envs as envs
@@ -50,18 +50,6 @@ def load_general_plugins():
            logger.exception("Failed to load plugin %s", plugin.name)


_torch_compile_backend: Optional[Union[Callable, str]] = None


def set_torch_compile_backend(backend: Union[Callable, str]):
    global _torch_compile_backend
    _torch_compile_backend = backend


def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
    return _torch_compile_backend


_compilation_config: Optional[CompilationConfig] = None

View File

@@ -1600,3 +1600,12 @@ def direct_register_custom_op(
    my_lib.impl(op_name, op_func, "CUDA")
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)


def resolve_obj_by_qualname(qualname: str) -> Any:
    """
    Resolve an object by its fully qualified name.
    """
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)
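
A quick illustration of the helper with a standard-library object; nothing vLLM-specific is assumed beyond the function added above:

import os.path

from vllm.utils import resolve_obj_by_qualname

# "os.path.join" splits into module "os.path" and attribute "join"
join = resolve_obj_by_qualname("os.path.join")
assert join is os.path.join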

View File

@@ -1143,8 +1143,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        if self.vllm_config.compilation_config.level ==\
            CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
            from vllm.plugins import get_torch_compile_backend
            backend = get_torch_compile_backend() or "eager"
            backend = self.vllm_config.compilation_config.init_backend()
            self.model = torch.compile(
                self.model,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,