[torch.compile] add a flag to disable custom op (#8488)
This commit is contained in:
parent
a36e070dad
commit
47790f3e32
@ -6,7 +6,8 @@ import pytest
|
||||
@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
|
||||
def test_full_graph(model):
|
||||
# make sure these models can be captured in full graph mode
|
||||
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
|
||||
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
|
||||
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
prompts = [
|
||||
|
||||
@ -202,6 +202,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
|
||||
("true", "1")),
|
||||
|
||||
# Internal flag to control whether we use custom op,
|
||||
# or use the native pytorch implementation
|
||||
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
|
||||
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
|
||||
|
||||
# Internal flag to enable Dynamo fullgraph capture
|
||||
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
|
||||
lambda: bool(
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import is_cpu, is_hip, is_xpu
|
||||
|
||||
@ -53,6 +54,10 @@ class CustomOp(nn.Module):
|
||||
def dispatch_forward(self):
|
||||
# NOTE(woosuk): Here we assume that vLLM was built for only one
|
||||
# specific backend. Currently, we do not support dynamic dispatching.
|
||||
|
||||
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
|
||||
return self.forward_native
|
||||
|
||||
if is_hip():
|
||||
return self.forward_hip
|
||||
elif is_cpu():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user