diff --git a/vllm/config.py b/vllm/config.py index f57aa404..00dd047e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1408,11 +1408,11 @@ class SpeculativeConfig: else: speculative_draft_tensor_parallel_size = \ target_parallel_config.tensor_parallel_size - elif speculative_draft_tensor_parallel_size != 1: - # TODO(wooyeon): allow tp values larger than 1 + elif speculative_draft_tensor_parallel_size not in ( + 1, target_parallel_config.tensor_parallel_size): raise ValueError( f"{speculative_draft_tensor_parallel_size=} cannot be " - f"other value than 1") + f"other value than 1 or target model tensor_parallel_size") draft_parallel_config = ParallelConfig( pipeline_parallel_size=target_parallel_config.