[Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size to 1 when using MLPSpeculator (#7105)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2024-08-04 16:13:18 +02:00 committed by GitHub
parent 179a6a36f2
commit b1c9aa3daa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1068,7 +1068,7 @@ class SpeculativeConfig:
draft_parallel_config = ( draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config( SpeculativeConfig.create_draft_parallel_config(
target_parallel_config, target_parallel_config,
speculative_draft_tensor_parallel_size)) speculative_draft_tensor_parallel_size, draft_hf_config))
if num_speculative_tokens is None: if num_speculative_tokens is None:
raise ValueError( raise ValueError(
@@ -1136,15 +1136,23 @@ class SpeculativeConfig:
@staticmethod @staticmethod
def create_draft_parallel_config( def create_draft_parallel_config(
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int] speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig,
) -> ParallelConfig: ) -> ParallelConfig:
"""Create a parallel config for use by the draft worker. """Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size. This is mostly a copy of the target parallel config, except the tp_size.
""" """
if speculative_draft_tensor_parallel_size is None: if speculative_draft_tensor_parallel_size is None:
speculative_draft_tensor_parallel_size = \ if draft_hf_config.model_type == "mlp_speculator":
target_parallel_config.tensor_parallel_size speculative_draft_tensor_parallel_size = 1
if target_parallel_config.tensor_parallel_size > 1:
logger.warning(
"MLPSpeculator cannot currently be run with tp>1; "
"setting speculative_draft_tensor_parallel_size=1")
else:
speculative_draft_tensor_parallel_size = \
target_parallel_config.tensor_parallel_size
elif speculative_draft_tensor_parallel_size != 1: elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1 # TODO(wooyeon): allow tp values larger than 1
raise ValueError( raise ValueError(