diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 445117ac..ec735331 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -7,8 +7,7 @@ from vllm.platforms import current_platform # Input scaling factors are no longer optional in _scaled_mm starting # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale -TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \ - if current_platform.is_rocm() else None +TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) def cutlass_fp8_supported() -> bool: @@ -166,8 +165,7 @@ def apply_fp8_linear( # Making sure the dummy tensor is on the same device as the weight global TORCH_DEVICE_IDENTITY - if (TORCH_DEVICE_IDENTITY is not None - and TORCH_DEVICE_IDENTITY.device != weight.device): + if TORCH_DEVICE_IDENTITY.device != weight.device: TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) # GEMM