[Spec Decoding] Use target model max length as default for draft model (#7706)
This commit is contained in:
parent
6925cdbeea
commit
9b73a2f498
@ -127,6 +127,7 @@ class ModelConfig:
|
|||||||
rope_theta: Optional[float] = None,
|
rope_theta: Optional[float] = None,
|
||||||
tokenizer_revision: Optional[str] = None,
|
tokenizer_revision: Optional[str] = None,
|
||||||
max_model_len: Optional[int] = None,
|
max_model_len: Optional[int] = None,
|
||||||
|
spec_target_max_model_len: Optional[int] = None,
|
||||||
quantization: Optional[str] = None,
|
quantization: Optional[str] = None,
|
||||||
quantization_param_path: Optional[str] = None,
|
quantization_param_path: Optional[str] = None,
|
||||||
enforce_eager: Optional[bool] = None,
|
enforce_eager: Optional[bool] = None,
|
||||||
@ -210,7 +211,8 @@ class ModelConfig:
|
|||||||
hf_config=self.hf_text_config,
|
hf_config=self.hf_text_config,
|
||||||
max_model_len=max_model_len,
|
max_model_len=max_model_len,
|
||||||
disable_sliding_window=self.disable_sliding_window,
|
disable_sliding_window=self.disable_sliding_window,
|
||||||
sliding_window_len=self.get_hf_config_sliding_window())
|
sliding_window_len=self.get_hf_config_sliding_window(),
|
||||||
|
spec_target_max_model_len=spec_target_max_model_len)
|
||||||
self.served_model_name = get_served_model_name(model,
|
self.served_model_name = get_served_model_name(model,
|
||||||
served_model_name)
|
served_model_name)
|
||||||
self.multimodal_config = self._init_multimodal_config(
|
self.multimodal_config = self._init_multimodal_config(
|
||||||
@ -1134,6 +1136,7 @@ class SpeculativeConfig:
|
|||||||
code_revision=draft_code_revision,
|
code_revision=draft_code_revision,
|
||||||
tokenizer_revision=target_model_config.tokenizer_revision,
|
tokenizer_revision=target_model_config.tokenizer_revision,
|
||||||
max_model_len=None,
|
max_model_len=None,
|
||||||
|
spec_target_max_model_len=target_model_config.max_model_len,
|
||||||
quantization=draft_quantization,
|
quantization=draft_quantization,
|
||||||
enforce_eager=target_model_config.enforce_eager,
|
enforce_eager=target_model_config.enforce_eager,
|
||||||
max_seq_len_to_capture=target_model_config.
|
max_seq_len_to_capture=target_model_config.
|
||||||
@ -1563,6 +1566,7 @@ def _get_and_verify_max_len(
|
|||||||
max_model_len: Optional[int],
|
max_model_len: Optional[int],
|
||||||
disable_sliding_window: bool,
|
disable_sliding_window: bool,
|
||||||
sliding_window_len: Optional[int],
|
sliding_window_len: Optional[int],
|
||||||
|
spec_target_max_model_len: Optional[int] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Get and verify the model's maximum length."""
|
"""Get and verify the model's maximum length."""
|
||||||
derived_max_model_len = float("inf")
|
derived_max_model_len = float("inf")
|
||||||
@ -1605,6 +1609,11 @@ def _get_and_verify_max_len(
|
|||||||
# If max_model_len is specified, we use it.
|
# If max_model_len is specified, we use it.
|
||||||
return max_model_len
|
return max_model_len
|
||||||
|
|
||||||
|
if spec_target_max_model_len is not None:
|
||||||
|
# If this is a speculative draft model, we use the max model len
|
||||||
|
# from the target model.
|
||||||
|
return spec_target_max_model_len
|
||||||
|
|
||||||
default_max_len = 2048
|
default_max_len = 2048
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The model's config.json does not contain any of the following "
|
"The model's config.json does not contain any of the following "
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user