[Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size to 1 when using MLPSpeculator (#7105)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2024-08-04 16:13:18 +02:00 committed by GitHub
parent 179a6a36f2
commit b1c9aa3daa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1068,7 +1068,7 @@ class SpeculativeConfig:
draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config,
speculative_draft_tensor_parallel_size))
speculative_draft_tensor_parallel_size, draft_hf_config))
if num_speculative_tokens is None:
raise ValueError(
@ -1136,15 +1136,23 @@ class SpeculativeConfig:
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int]
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(