[Bugfix][TPU] Do not use torch.Generator for TPUs (#6981)
parent 1d2e7fb73f
commit 23993a7997
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import (QuantizationConfig,
                                                       get_quantization_config)
 from vllm.model_executor.layers.quantization.schema import QuantParamSchema
+from vllm.platforms import current_platform
 from vllm.utils import print_warning_once
 
 logger = init_logger(__name__)
@@ -490,6 +491,11 @@ def initialize_dummy_weights(
     """
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
+            if current_platform.is_tpu():
+                # XLA device does not support torch.Generator()
+                param.uniform_(low, high)
+                continue
+
             generator = torch.Generator(device=param.data.device)
             generator.manual_seed(seed)
             if torch.finfo(param.data.dtype).bits < 16:
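For context, the pattern the patch applies reduces to the sketch below. This is a minimal, hedged illustration, not the actual vLLM function: the name fill_dummy_weights and the low/high/seed defaults are illustrative, and the low-bit-dtype handling hinted at by the last context line of the hunk is omitted. Only current_platform.is_tpu(), param.uniform_, and the torch.Generator usage come from the diff itself.

import torch

from vllm.platforms import current_platform


def fill_dummy_weights(model: torch.nn.Module,
                       low: float = -1e-3,
                       high: float = 1e-3,
                       seed: int = 1234) -> None:
    # Fill every floating-point parameter with uniform noise in [low, high).
    for param in model.state_dict().values():
        if not torch.is_floating_point(param):
            continue
        if current_platform.is_tpu():
            # XLA devices do not support torch.Generator(), so draw the
            # values without one (the dummy weights are then not seeded).
            param.uniform_(low, high)
            continue
        # On CPU/GPU, a per-device generator keeps the dummy values
        # reproducible for a given seed.
        generator = torch.Generator(device=param.data.device)
        generator.manual_seed(seed)
        param.uniform_(low, high, generator=generator)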