[torch.compile] use interpreter with stable api from pytorch (#9889)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 4581d2cc02
commit aff1fd8188
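
The `PiecewiseCompileInterpreter` added below follows the standard `torch.fx.Interpreter` recipe (the same one `torch.fx.passes.shape_prop.ShapeProp` uses): subclass the interpreter, override `call_module` to intercept each submodule call, and drive the whole graph once with `run()`. A minimal, self-contained sketch of that pattern, using only stock PyTorch; the class and module names here are illustrative and not part of this diff:

import torch
import torch.fx


class SubmoduleLogger(torch.fx.Interpreter):
    # illustrative subclass: intercept every call_module node while the
    # graph executes, the same hook PiecewiseCompileInterpreter overrides
    def call_module(self, target, args, kwargs):
        output = super().call_module(target, args, kwargs)
        print(f"ran submodule {target!r} with {len(args)} positional args")
        return output


class TwoLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.a = torch.nn.Linear(4, 4)
        self.b = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.b(torch.relu(self.a(x)))


gm = torch.fx.symbolic_trace(TwoLinear())
SubmoduleLogger(gm).run(torch.randn(2, 4))  # prints once per Linear submodule

The class in this commit additionally converts its inputs to fake tensors inside `run()` (via `detect_fake_mode`), so the pass can execute the graph without real data.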
@@ -243,6 +243,65 @@ def split_graph(graph: fx.GraphModule,
     return split_gm, outputs


+# we share the global graph pool among all the backends
+global_graph_pool = None
+
+
+class PiecewiseCompileInterpreter(torch.fx.Interpreter):
+    """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
+    It runs the given graph with fake inputs, and compile some
+    submodules specified by `compile_submod_names` with the given
+    compilation configs.
+    """
+
+    def __init__(self, module: torch.fx.GraphModule,
+                 compile_submod_names: List[str],
+                 compilation_configs: CompilationConfig, graph_pool):
+        super().__init__(module)
+        from torch._guards import detect_fake_mode
+        self.fake_mode = detect_fake_mode()
+        self.compile_submod_names = compile_submod_names
+        self.compilation_configs = compilation_configs
+        self.graph_pool = graph_pool
+        self.have_seen_first_graph = False
+
+    def run(self, *args):
+        fake_args = [
+            self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
+            for t in args
+        ]
+        return super().run(*fake_args)
+
+    def call_module(self, target: torch.fx.node.Target,
+                    args: Tuple[torch.fx.node.Argument,
+                                ...], kwargs: Dict[str, Any]) -> Any:
+        assert isinstance(target, str)
+        output = super().call_module(target, args, kwargs)
+
+        if target in self.compile_submod_names:
+            submod = self.fetch_attr(target)
+            sym_shape_indices = [
+                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
+            ]
+            compiled_graph_for_general_shape = wrap_inductor(
+                submod,
+                args,
+                self.compilation_configs.inductor_compile_config,
+                runtime_shape=None,
+                do_logging=not self.have_seen_first_graph,
+                use_inductor=self.compilation_configs.use_inductor)
+
+            self.module.__dict__[target] = PiecewiseBackend(
+                submod, self.compilation_configs, self.graph_pool,
+                not self.have_seen_first_graph, sym_shape_indices,
+                compiled_graph_for_general_shape)
+
+            self.have_seen_first_graph = True
+            compilation_counter.num_piecewise_capturable_graphs_seen += 1
+
+        return output
+
+
 class VllmBackend:
     """The compilation backend for `torch.compile` with VLLM.
     It is used for compilation level of `CompilationLevel.PIECEWISE`,
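
One detail worth noting in `call_module` above: the compiled piece is installed with `self.module.__dict__[target] = PiecewiseBackend(...)` rather than `setattr`, echoing the older comment "cannot setattr to a module" that appears further down in this diff. `nn.Module.__setattr__` refuses to replace a registered submodule with a plain callable, whereas an instance `__dict__` entry shadows the registered submodule on attribute lookup. A small standalone demonstration with made-up names, not taken from vLLM:

import torch


class Outer(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(4, 4)  # registered in self._modules

    def forward(self, x):
        return self.inner(x)


outer = Outer()
try:
    outer.inner = lambda x: x + 1  # rejected: non-Module assigned to a child-module name
except TypeError as err:
    print("setattr failed:", err)

outer.__dict__["inner"] = lambda x: x + 1  # instance __dict__ wins over _modules on lookup
print(outer(torch.zeros(2, 4)))  # tensor of ones: the replacement callable ran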
@@ -263,8 +322,14 @@ class VllmBackend:
     returned_callable: Callable

     def __init__(self, ):
-        # every instance of VllmBackend has its own graph pool
-        self.graph_pool = torch.cuda.graph_pool_handle()
+        global global_graph_pool
+        if global_graph_pool is None:
+            global_graph_pool = torch.cuda.graph_pool_handle()
+
+        # TODO: in the future, if we want to use multiple
+        # streams, it might not be safe to share a global pool.
+        # only investigate this when we use multiple streams
+        self.graph_pool = global_graph_pool

         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
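
For context on the pool being shared above: `torch.cuda.graph_pool_handle()` returns an opaque pool id, and passing the same handle as `pool=` to several captures makes those CUDA graphs allocate from one shared memory pool instead of each reserving its own. A hedged standalone sketch of that mechanism (not vLLM code; requires a CUDA device):

import torch

if torch.cuda.is_available():
    pool = torch.cuda.graph_pool_handle()
    static_x = torch.zeros(8, device="cuda")

    # warm up on a side stream before capture, per the usual CUDA graphs recipe
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_x + 1
    torch.cuda.current_stream().wait_stream(s)

    g1, g2 = torch.cuda.CUDAGraph(), torch.cuda.CUDAGraph()
    with torch.cuda.graph(g1, pool=pool):  # both captures draw from the same pool
        y1 = static_x + 1
    with torch.cuda.graph(g2, pool=pool):
        y2 = static_x * 2

    g1.replay()
    g2.replay()
    print(y1[0].item(), y2[0].item())  # 1.0 0.0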
@@ -286,55 +351,26 @@ class VllmBackend:
         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_configs.non_cudagraph_ops)

-        returned_callable: Callable  # type: ignore
+        from torch._dynamo.utils import lazy_format_graph_code
+        logger.debug("%s",
+                     lazy_format_graph_code("stiching module", self.split_gm))

-        if len(self.piecewise_graphs) == 0:
-            compilation_counter.num_piecewise_graphs_seen += 1
-            compilation_counter.num_piecewise_capturable_graphs_seen += 1
-            returned_callable = PiecewiseBackend(graph,
-                                                 self.compilation_configs,
-                                                 self.graph_pool,
-                                                 is_first_graph=True)
-        else:
-            from torch._dynamo.utils import lazy_format_graph_code
-            logger.debug(
-                "%s", lazy_format_graph_code("stiching module", self.split_gm))
+        compilation_counter.num_piecewise_graphs_seen += len(
+            self.piecewise_graphs)
+        submod_names_to_compile = [
+            item.submod_name for item in self.piecewise_graphs
+            if not item.is_splitting_graph
+        ]

-            is_first_graph = True
-
-            for item in self.piecewise_graphs:
-                compilation_counter.num_piecewise_graphs_seen += 1
-                compilation_counter.num_piecewise_capturable_graphs_seen += not item.is_splitting_graph  # noqa
-                if not item.is_splitting_graph:
-                    # cannot setattr to a module, so we need to set
-                    # the attribute in the __dict__
-                    self.split_gm.__dict__[
-                        item.submod_name] = PiecewiseBackend(
-                            item.graph, self.compilation_configs,
-                            self.graph_pool, is_first_graph)
-                    is_first_graph = False
-            returned_callable = self.split_gm
-
-        self.returned_callable = returned_callable
-        # trigger the first compilation
-        # code borrowed from https://github.com/pytorch/pytorch/blob/4e3e08b71171fa34172b2362ff668553fac75f27/torch/_dynamo/backends/distributed.py#L206 # noqa
-        # to turn the inputs into fake tensors
-        import torch._guards
-        from torch._guards import detect_fake_mode
-        fake_mode = detect_fake_mode(example_inputs)
-        fake_args = []
-        for arg in example_inputs:
-            if isinstance(arg, torch.Tensor) and not isinstance(
-                    arg, torch._subclasses.FakeTensor):
-                fake_args.append(
-                    torch._dynamo.utils.to_fake_tensor(arg, fake_mode))
-            else:
-                fake_args.append(arg)
-        self.returned_callable(*fake_args)
+        # propagate the split graph to the piecewise backend,
+        # compile submodules with symbolic shapes
+        PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
+                                    self.compilation_configs,
+                                    self.graph_pool).run(*example_inputs)

         self._called = True

-        return self.returned_callable
+        return self.split_gm


 @dataclasses.dataclass
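
The block removed above was the old way of getting fake inputs: look up the ambient fake mode, then convert the Dynamo example inputs by hand with private `torch._dynamo.utils` helpers. The new `PiecewiseCompileInterpreter.run()` performs the same conversion with `fake_mode.from_tensor` before interpreting the graph. A small standalone sketch of what a fake tensor is, using `torch._subclasses` as the removed code did (illustrative only):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
real = torch.randn(4, 4)
fake = fake_mode.from_tensor(real)  # same shape/dtype/device metadata, no real storage
out = fake @ fake                   # dispatches through the fake mode, no real compute
print(type(out).__name__, out.shape)  # FakeTensor torch.Size([4, 4])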
@@ -352,11 +388,10 @@ class ConcreteSizeEntry:

 class PiecewiseBackend:

-    def __init__(self,
-                 graph: fx.GraphModule,
-                 compilation_configs: CompilationConfig,
-                 graph_pool: Any,
-                 is_first_graph: bool = False):
+    def __init__(self, graph: fx.GraphModule,
+                 compilation_configs: CompilationConfig, graph_pool: Any,
+                 is_first_graph: bool, sym_shape_indices: List[int],
+                 compiled_graph_for_general_shape: Callable):
         """
         The backend for piecewise compilation.
         It mainly handles the compilation and cudagraph capturing.
@@ -381,12 +416,11 @@ class PiecewiseBackend:
             self.compilation_configs.capture_sizes
         ) if self.compilation_configs.use_cudagraph else set()

-        self.compile_finished = False
         self.first_run_finished = False

-        self.compiled_graph_for_general_shape: Callable = None  # type: ignore
+        self.compiled_graph_for_general_shape = compiled_graph_for_general_shape  # noqa

-        self.sym_shape_indices: List[int] = []
+        self.sym_shape_indices = sym_shape_indices

         # the entries for different shapes that we need to either
         # compile or capture cudagraph
@@ -399,27 +433,6 @@ class PiecewiseBackend:
            )

     def __call__(self, *args) -> Any:
-
-        if not self.compile_finished:
-            self.compile_finished = True
-
-            # this is the first compilation, we will compile a graph with
-            # dynamic shape, as the caller will mark first dimension as dynamic
-
-            self.sym_shape_indices = [
-                i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
-            ]
-
-            self.compiled_graph_for_general_shape = wrap_inductor(
-                self.graph,
-                args,
-                self.compilation_configs.inductor_compile_config,
-                runtime_shape=None,
-                do_logging=self.is_first_graph,
-                use_inductor=self.compilation_configs.use_inductor)
-
-            return self.graph(*args)
-
         if not self.first_run_finished:
             self.first_run_finished = True
             return self.compiled_graph_for_general_shape(*args)
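
Zooming out, `VllmBackend` is an ordinary TorchDynamo custom backend: a callable that receives the captured `fx.GraphModule` and its example inputs and returns the callable to run in its place (here, the stitched `split_gm`). A minimal sketch of that contract with illustrative names, not vLLM code:

import torch
from typing import List


def noop_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # a real backend would inspect or rewrite `gm` here; VllmBackend splits it,
    # compiles the pieces, and returns the stitched module instead
    return gm.forward  # minimal backend: run the captured graph as-is


@torch.compile(backend=noop_backend)
def f(x):
    return torch.relu(x) + 1


print(f(torch.randn(3)))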