[Bugfix][TPU] Do not use torch.Generator for TPUs (#6981)
This commit is contained in:
parent
1d2e7fb73f
commit
23993a7997
@ -22,6 +22,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
get_quantization_config)
|
||||
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -490,6 +491,11 @@ def initialize_dummy_weights(
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
if torch.is_floating_point(param):
|
||||
if current_platform.is_tpu():
|
||||
# XLA device does not support torch.Generator()
|
||||
param.uniform_(low, high)
|
||||
continue
|
||||
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user