vllm/vllm/compilation/wrapper.py
youkaichao ce6bf3a2cf
[torch.compile] avoid Dynamo guard evaluation overhead (#7898)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2024-08-28 16:10:12 -07:00

82 lines
3.3 KiB
Python

import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List
import torch
import vllm.envs as envs
class TorchCompileWrapperWithCustomDispacther:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.

    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.

    NOTE: the code-object swap done by `dispatch_to_code` is applied to
    the `forward` function stored on the class, so while a dispatch is
    active it is visible to every instance of that class.
    """

    def __init__(self, compiled_callable: Callable):
        """Record the compiled callable and start collecting bytecode.

        Args:
            compiled_callable: the callable invoked by the default
                `__call__` (presumably `torch.compile` applied to
                `forward` -- the subclass decides how it is built).
        """
        self.compiled_callable = compiled_callable
        # Code object of the class-level `forward`; used both to recognise
        # our own compilations in `bytecode_hook` and to restore `forward`
        # after a direct dispatch.
        self.original_code_object = self.__class__.forward.__code__
        # Compiled code objects, in the order the hook received them.
        self.compiled_codes: List[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # read the env var to determine whether to use the custom dispatcher
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.

        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.

        The default implementation simply forwards all arguments to
        `self.compiled_callable` and returns its result.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        """The method whose code object is swapped; subclasses implement it."""
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Args:
            old_code: the code object that was compiled.
            new_code: the code object produced for it.

        `new_code` is appended to `self.compiled_codes` only when
        `old_code` is this class's original `forward` code object and the
        frame being compiled has this very instance as its `self`.
        """
        if old_code is not self.original_code_object:
            # Some other function was compiled; not ours to record.
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the Python call stack until we reach the `_compile`
        # function defined in a file named `convert_frame.py`.
        frame = sys._getframe()
        while True:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        # That function holds the frame under compilation in a local
        # variable called `frame`.
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            # Same class, different instance: every instance registers its
            # own hook, so leave this compilation to the matching one.
            return

        self.compiled_codes.append(new_code)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.

        Inside the `with` block the class-level `forward` runs
        `self.compiled_codes[index]`; on exit -- including when the block
        raises -- `forward` is switched back to the original code object.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.

        Raises:
            IndexError: if `index` is out of range for
                `self.compiled_codes` (raised before anything is swapped).
        """ # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Always undo the swap; otherwise an exception raised by the
            # caller's block would leave `forward` permanently pointing at
            # the compiled bytecode.
            self.__class__.forward.__code__ = self.original_code_object