[Misc] Fix input_scale typing in w8a8_utils.py (#6579)

This commit is contained in:
Michael Goin 2024-07-20 19:11:13 -04:00 committed by GitHub
parent 9364f74eee
commit f952bbc8ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -104,7 +104,7 @@ def apply_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
-    input_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
     input_scale_ub: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
@@ -192,7 +192,7 @@ def apply_int8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
-    input_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
# ops.scaled_int8_quant supports both dynamic and static quant.