[Bugfix] Fix FP8 torch._scaled_mm fallback for torch>2.5 with CUDA<12.4 (#10095)
Signed-off-by: mgoin <michael@neuralmagic.com>
parent 719c1ca468
commit 4ab3256644
@@ -7,8 +7,7 @@ from vllm.platforms import current_platform
 
 # Input scaling factors are no longer optional in _scaled_mm starting
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
-TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() \
-    if current_platform.is_rocm() else None
+TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
 
 
 def cutlass_fp8_supported() -> bool:
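For context on why this dummy tensor exists at all: as the comment in the hunk notes, torch._scaled_mm made its scale arguments mandatory starting with PyTorch 2.5, so the fallback path passes the identity scale when scaling is handled outside the kernel. A minimal sketch of that call follows; qinput, weight, and the out_dtype are assumptions for illustration, not the exact vLLM call site.

# Sketch only: qinput/weight are assumed to be fp8 (e.g. torch.float8_e4m3fn)
# tensors with _scaled_mm-compatible layouts, and out_dtype is assumed.
# TORCH_DEVICE_IDENTITY stands in for the now-required scale_a/scale_b.
output = torch._scaled_mm(qinput,
                          weight,
                          scale_a=TORCH_DEVICE_IDENTITY,
                          scale_b=TORCH_DEVICE_IDENTITY,
                          out_dtype=torch.float32)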
@@ -166,8 +165,7 @@ def apply_fp8_linear(
 
     # Making sure the dummy tensor is on the same device as the weight
     global TORCH_DEVICE_IDENTITY
-    if (TORCH_DEVICE_IDENTITY is not None
-            and TORCH_DEVICE_IDENTITY.device != weight.device):
+    if TORCH_DEVICE_IDENTITY.device != weight.device:
         TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
 
     # GEMM
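This hunk works together with the new import-time allocation above: the identity tensor now always exists, starting on the CPU, and is migrated to the weight's device on first use, so no CUDA context is created at import time and the None check can go away. A small self-contained sketch of that lazy-move pattern, where ensure_identity_on is a hypothetical helper name (the patch performs the same logic inline in apply_fp8_linear):

import torch

# Allocated on CPU at import time; never triggers CUDA initialization here.
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)

# Hypothetical helper (not part of the patch) illustrating the move-on-demand
# logic: relocate the identity scale only when the target device changes.
def ensure_identity_on(device: torch.device) -> torch.Tensor:
    global TORCH_DEVICE_IDENTITY
    if TORCH_DEVICE_IDENTITY.device != device:
        TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(device)
    return TORCH_DEVICE_IDENTITY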