[Bugfix][TPU] Do not use torch.Generator for TPUs (#6981)

Author: Woosuk Kwon (committed by GitHub)
Date: 2024-07-31 18:50:28 -07:00
parent 1d2e7fb73f
commit 23993a7997


@@ -22,6 +22,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                       get_quantization_config)
 from vllm.model_executor.layers.quantization.schema import QuantParamSchema
+from vllm.platforms import current_platform
 from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
@@ -490,6 +491,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            if current_platform.is_tpu():
+                # XLA device does not support torch.Generator()
+                param.uniform_(low, high)
+                continue
+
             generator = torch.Generator(device=param.data.device)
             generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
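
For context, here is a minimal self-contained sketch of how the patched loop behaves on TPU versus CPU/GPU. This is not the upstream implementation: the parameter defaults are illustrative, the param.device.type == "xla" check stands in for vLLM's current_platform.is_tpu(), and the sub-16-bit (FP8) handling visible at the end of the hunk is omitted.

    import torch


    def initialize_dummy_weights(model: torch.nn.Module,
                                 low: float = -1e-3,
                                 high: float = 1e-3,
                                 seed: int = 1234) -> None:
        # Fill every floating-point parameter with small uniform noise.
        for param in model.state_dict().values():
            if not torch.is_floating_point(param):
                continue
            if param.device.type == "xla":
                # XLA (TPU) tensors cannot use a torch.Generator, so sample
                # in place without one (the behavior this commit adds).
                param.uniform_(low, high)
                continue
            # CPU/GPU path: a seeded per-parameter generator keeps the dummy
            # weights reproducible across runs.
            generator = torch.Generator(device=param.device)
            generator.manual_seed(seed)
            param.uniform_(low, high, generator=generator)

On a CPU-only machine, initialize_dummy_weights(torch.nn.Linear(4, 4)) fills the weight and bias with values in [low, high]; on a TPU the same call simply skips the generator and samples in place.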