Fix assertion failure in Qwen 1.5 with prefix caching enabled (#3373)
Co-authored-by: Cade Daniel <edacih@gmail.com>
This commit is contained in:
parent
dfc77408bd
commit
54be8a0be2
43
tests/test_config.py
Normal file
43
tests/test_config.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from vllm.config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_sliding_window():
|
||||||
|
TEST_SLIDING_WINDOW = 4096
|
||||||
|
# Test that the sliding window is correctly computed.
|
||||||
|
# For Qwen1.5/Qwen2, get_sliding_window() should be None
|
||||||
|
# when use_sliding_window is False.
|
||||||
|
qwen2_model_config = ModelConfig(
|
||||||
|
"Qwen/Qwen1.5-7B",
|
||||||
|
"Qwen/Qwen1.5-7B",
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=False,
|
||||||
|
download_dir=None,
|
||||||
|
load_format="dummy",
|
||||||
|
seed=0,
|
||||||
|
dtype="float16",
|
||||||
|
revision=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
qwen2_model_config.hf_config.use_sliding_window = False
|
||||||
|
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||||
|
assert qwen2_model_config.get_sliding_window() is None
|
||||||
|
|
||||||
|
qwen2_model_config.hf_config.use_sliding_window = True
|
||||||
|
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||||
|
|
||||||
|
mistral_model_config = ModelConfig(
|
||||||
|
"mistralai/Mistral-7B-v0.1",
|
||||||
|
"mistralai/Mistral-7B-v0.1",
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
trust_remote_code=False,
|
||||||
|
download_dir=None,
|
||||||
|
load_format="dummy",
|
||||||
|
seed=0,
|
||||||
|
dtype="float16",
|
||||||
|
revision=None,
|
||||||
|
)
|
||||||
|
mistral_model_config.hf_config.sliding_window = None
|
||||||
|
assert mistral_model_config.get_sliding_window() is None
|
||||||
|
|
||||||
|
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||||
|
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||||
@ -103,6 +103,7 @@ class ModelConfig:
|
|||||||
# download model from ModelScope hub,
|
# download model from ModelScope hub,
|
||||||
# lazy import so that modelscope is not required for normal use.
|
# lazy import so that modelscope is not required for normal use.
|
||||||
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
|
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C
|
||||||
|
|
||||||
if not os.path.exists(model):
|
if not os.path.exists(model):
|
||||||
model_path = snapshot_download(model_id=model,
|
model_path = snapshot_download(model_id=model,
|
||||||
cache_dir=download_dir,
|
cache_dir=download_dir,
|
||||||
@ -139,7 +140,7 @@ class ModelConfig:
|
|||||||
if (f not in rocm_not_supported_load_format)
|
if (f not in rocm_not_supported_load_format)
|
||||||
]
|
]
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"load format \'{load_format}\' is not supported in ROCm. "
|
f"load format '{load_format}' is not supported in ROCm. "
|
||||||
f"Supported load format are "
|
f"Supported load format are "
|
||||||
f"{rocm_supported_load_format}")
|
f"{rocm_supported_load_format}")
|
||||||
|
|
||||||
@ -232,6 +233,15 @@ class ModelConfig:
|
|||||||
f"({pipeline_parallel_size}).")
|
f"({pipeline_parallel_size}).")
|
||||||
|
|
||||||
def get_sliding_window(self) -> Optional[int]:
|
def get_sliding_window(self) -> Optional[int]:
|
||||||
|
"""Get the sliding window size, or None if disabled.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
|
||||||
|
# addition to sliding window size. We check if that field is present
|
||||||
|
# and if it's False, return None.
|
||||||
|
if (hasattr(self.hf_config, "use_sliding_window")
|
||||||
|
and not self.hf_config.use_sliding_window):
|
||||||
|
return None
|
||||||
return getattr(self.hf_config, "sliding_window", None)
|
return getattr(self.hf_config, "sliding_window", None)
|
||||||
|
|
||||||
def get_vocab_size(self) -> int:
|
def get_vocab_size(self) -> int:
|
||||||
@ -624,7 +634,7 @@ def _get_and_verify_dtype(
|
|||||||
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
||||||
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
||||||
]
|
]
|
||||||
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
|
raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
|
||||||
f"Supported dtypes are {rocm_supported_dtypes}")
|
f"Supported dtypes are {rocm_supported_dtypes}")
|
||||||
|
|
||||||
# Verify the dtype.
|
# Verify the dtype.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user