diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py index cef516ad..3668c1fa 100644 --- a/tests/compile/test_wrapper.py +++ b/tests/compile/test_wrapper.py @@ -2,7 +2,7 @@ from typing import Optional import torch -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher class MyMod(torch.nn.Module): @@ -13,7 +13,7 @@ class MyMod(torch.nn.Module): return x * 2 -class MyWrapper(TorchCompileWrapperWithCustomDispacther): +class MyWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model): self.model = model diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index c3d86329..e923bd36 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -10,7 +10,7 @@ import torch import vllm.envs as envs -class TorchCompileWrapperWithCustomDispacther: +class TorchCompileWrapperWithCustomDispatcher: """ A wrapper class for torch.compile, with a custom dispatch logic. Subclasses should: diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 684c54b7..db306bc7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -11,7 +11,7 @@ import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger @@ -611,7 +611,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): return [SamplerOutput(sampler_outputs)] -class ModelWrapper(TorchCompileWrapperWithCustomDispacther): +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model: nn.Module): self.model = model