[torch.compile] add a flag to disable custom op (#8488)

youkaichao 2024-09-14 13:07:16 -07:00 committed by GitHub
parent a36e070dad
commit 47790f3e32
3 changed files with 12 additions and 1 deletion
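
Taken together: the commit registers a new test-only environment variable, VLLM_TEST_COMPILE_NO_CUSTOM_OPS, and makes CustomOp.dispatch_forward return the native PyTorch implementation whenever it is set. A minimal usage sketch (assumed, not part of the commit; it relies on vllm.envs resolving each variable lazily through the lambdas added below):

    import os

    # Set the flag before any vLLM code reads it; any nonzero integer
    # string disables the custom-op kernels in favor of forward_native.
    os.environ["VLLM_TEST_COMPILE_NO_CUSTOM_OPS"] = "1"

    import vllm.envs as envs
    assert envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS == 1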

tests/compile/test_full_graph.py

@@ -6,7 +6,8 @@ import pytest
 @pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
 def test_full_graph(model):
     # make sure these models can be captured in full graph mode
-    os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
+    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
+        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
     from vllm import LLM, SamplingParams
     prompts = [
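
The guard above lets an externally exported VLLM_TEST_DYNAMO_GRAPH_CAPTURE value take precedence instead of being unconditionally overwritten. An equivalent idiom (a sketch of the pattern, not what the commit uses):

    import os

    # Keep a caller-provided value; fall back to "1" only when unset.
    os.environ.setdefault("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "1")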

vllm/envs.py

@@ -202,6 +202,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
      ("true", "1")),
 
+    # Internal flag to control whether we use custom op,
+    # or use the native pytorch implementation
+    "VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
+    lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
+
     # Internal flag to enable Dynamo fullgraph capture
     "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
     lambda: bool(
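
Each entry in environment_variables maps a name to a zero-argument lambda, so the value is read from os.environ at access time rather than at import. A standalone equivalent of the new entry (a sketch; the local variable name is illustrative):

    import os

    # int(...) with a "0" default: unset or "0" disables the flag,
    # "1" (or any nonzero integer string) enables it.
    no_custom_ops = int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0"))
    if no_custom_ops:
        print("custom ops disabled; using native PyTorch forward")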

vllm/model_executor/custom_op.py

@@ -1,5 +1,6 @@
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.platforms import current_platform
 from vllm.utils import is_cpu, is_hip, is_xpu
@@ -53,6 +54,10 @@ class CustomOp(nn.Module):
     def dispatch_forward(self):
         # NOTE(woosuk): Here we assume that vLLM was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
+
+        if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
+            return self.forward_native
+
         if is_hip():
             return self.forward_hip
         elif is_cpu():
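
With the early return in place, the flag is consulted before any platform check, so one environment variable forces the pure-PyTorch path on every backend. A self-contained sketch of the same dispatch pattern (hypothetical FakeOp, not the vLLM class):

    import os

    class FakeOp:
        def forward_native(self, x):
            return x * 2  # reference implementation in plain PyTorch ops

        def forward_cuda(self, x):
            return x * 2  # stands in for a fused custom kernel

        def dispatch_forward(self):
            # The test flag wins over platform dispatch, mirroring the diff above.
            if int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")):
                return self.forward_native
            return self.forward_cuda

    op = FakeOp()
    os.environ["VLLM_TEST_COMPILE_NO_CUSTOM_OPS"] = "1"
    assert op.dispatch_forward() == op.forward_native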