diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 942215da..5e142e8c 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -22,6 +22,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) from vllm.model_executor.layers.quantization.schema import QuantParamSchema +from vllm.platforms import current_platform from vllm.utils import print_warning_once logger = init_logger(__name__) @@ -490,6 +491,11 @@ def initialize_dummy_weights( """ for param in model.state_dict().values(): if torch.is_floating_point(param): + if current_platform.is_tpu(): + # XLA device does not support torch.Generator() + param.uniform_(low, high) + continue + generator = torch.Generator(device=param.data.device) generator.manual_seed(seed) if torch.finfo(param.data.dtype).bits < 16: