[ci][bugfix] fix kernel tests (#10431)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
a03ea40792
commit
2298e69b5f
@ -6,9 +6,6 @@ import vllm.envs as envs
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import CompilationConfig, VllmConfig
|
from vllm.config import CompilationConfig, VllmConfig
|
||||||
else:
|
|
||||||
CompilationConfig = None
|
|
||||||
VllmConfig = None
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -50,23 +47,23 @@ def load_general_plugins():
|
|||||||
logger.exception("Failed to load plugin %s", plugin.name)
|
logger.exception("Failed to load plugin %s", plugin.name)
|
||||||
|
|
||||||
|
|
||||||
_compilation_config: Optional[CompilationConfig] = None
|
_compilation_config: Optional["CompilationConfig"] = None
|
||||||
|
|
||||||
|
|
||||||
def set_compilation_config(config: Optional[CompilationConfig]):
|
def set_compilation_config(config: Optional["CompilationConfig"]):
|
||||||
global _compilation_config
|
global _compilation_config
|
||||||
_compilation_config = config
|
_compilation_config = config
|
||||||
|
|
||||||
|
|
||||||
def get_compilation_config() -> Optional[CompilationConfig]:
|
def get_compilation_config() -> Optional["CompilationConfig"]:
|
||||||
return _compilation_config
|
return _compilation_config
|
||||||
|
|
||||||
|
|
||||||
_current_vllm_config: Optional[VllmConfig] = None
|
_current_vllm_config: Optional["VllmConfig"] = None
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def set_current_vllm_config(vllm_config: VllmConfig):
|
def set_current_vllm_config(vllm_config: "VllmConfig"):
|
||||||
"""
|
"""
|
||||||
Temporarily set the current VLLM config.
|
Temporarily set the current VLLM config.
|
||||||
Used during model initialization.
|
Used during model initialization.
|
||||||
@ -87,6 +84,12 @@ def set_current_vllm_config(vllm_config: VllmConfig):
|
|||||||
_current_vllm_config = old_vllm_config
|
_current_vllm_config = old_vllm_config
|
||||||
|
|
||||||
|
|
||||||
def get_current_vllm_config() -> VllmConfig:
|
def get_current_vllm_config() -> "VllmConfig":
|
||||||
assert _current_vllm_config is not None, "Current VLLM config is not set."
|
if _current_vllm_config is None:
|
||||||
|
# in ci, usually when we test custom ops/modules directly,
|
||||||
|
# we don't set the vllm config. In that case, we set a default
|
||||||
|
# config.
|
||||||
|
logger.warning("Current VLLM config is not set.")
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
return VllmConfig()
|
||||||
return _current_vllm_config
|
return _current_vllm_config
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user