[Kernel] Remove scaled_fp8_quant kernel padding footgun (#6842)
parent 052b6f8ca4
commit d7a299edaa
@@ -123,7 +123,7 @@ def test_scaled_fp8_quant(dtype) -> None:
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
 
     # Padding
-    y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
     assert y.shape[0] == 17
     assert torch.allclose(
         ref_y,
@@ -307,7 +307,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
 def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
-    batch_dim_padding: Optional[int] = None,
+    num_token_padding: Optional[int] = None,
     scale_ub: Optional[torch.Tensor] = None,
     use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -317,7 +317,7 @@ def scaled_fp8_quant(
     This function supports both static and dynamic quantization: If you
     provide the scale, it will use static scaling and if you omit it,
     the scale will be determined dynamically. The function also allows
-    optional padding of the output tensor for downstream kernels that
+    optional padding of the output tensors for downstream kernels that
     will benefit from padding.
 
     Args:
@@ -325,7 +325,7 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         scale_ub: Optional upper bound for scaling factor in dynamic
            per token case
-        batch_dim_padding: If specified, pad the first dimension
+        num_token_padding: If specified, pad the first dimension
            of the output to at least this value.
        use_per_token_if_dynamic: Whether to do per_tensor or per_token
            in the dynamic quantization case.
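Taken together, the renamed signature and docstring describe two entry points: static quantization with a caller-supplied per-tensor scale, and dynamic quantization where the kernel computes the scale. A minimal usage sketch of the renamed keyword, assuming the usual `from vllm import _custom_ops as ops` import and a CUDA device with FP8 (e4m3) support; the shapes below are illustrative only:

import torch
from vllm import _custom_ops as ops

# 2-D (num_tokens, hidden_size) input, as required by the new assert.
x = torch.randn(5, 16, dtype=torch.float16, device="cuda")

# Dynamic per-tensor quantization: the scale is computed by the kernel.
y, scale = ops.scaled_fp8_quant(x)

# Static quantization with a per-tensor scale, padded to 17 rows
# (note the renamed keyword: num_token_padding, not batch_dim_padding).
inv_scale = torch.full((1,), 0.5, device="cuda", dtype=torch.float32)
y_pad, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
assert y_pad.shape == (17, 16)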
@@ -334,16 +334,16 @@ def scaled_fp8_quant(
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
             scaling factor.
     """
-    if batch_dim_padding:
-        shape = (max(batch_dim_padding, input.shape[0]), *input.shape[1:])
-        output = torch.empty(shape,
-                             device=input.device,
-                             dtype=torch.float8_e4m3fn)
-    else:
-        output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+    # This code assumes batch_dim and num_tokens are flattened
+    assert (input.ndim == 2)
+    shape = input.shape
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+
     if scale is None:
         if use_per_token_if_dynamic:
-            scale = torch.empty((input.numel() // input.shape[-1], 1),
+            scale = torch.empty((shape[0], 1),
                                 device=input.device,
                                 dtype=torch.float32)
             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
@@ -352,6 +352,8 @@ def scaled_fp8_quant(
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
+        # num_token_padding not implemented for this case
+        assert (scale.numel() == 1 or num_token_padding is None)
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale
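The key behavioral change is that the dynamic per-token path now allocates the scale with `shape[0]` rows, so when padding is requested both returned tensors carry the padded token dimension. A hedged illustration of the resulting shapes, assuming a 5-token input padded to 17 rows on a CUDA device with FP8 support:

import torch
from vllm import _custom_ops as ops

x = torch.randn(5, 16, dtype=torch.float16, device="cuda")
y, x_scale = ops.scaled_fp8_quant(x,
                                  num_token_padding=17,
                                  use_per_token_if_dynamic=True)

# Both the quantized output and the per-token scale are padded, so a
# caller that unpads one must unpad the other (see apply_fp8_linear below).
assert y.shape == (17, 16)
assert x_scale.shape == (17, 1)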
@@ -139,7 +139,7 @@ def apply_fp8_linear(
     qinput, x_scale = ops.scaled_fp8_quant(
         input,
         input_scale,
-        batch_dim_padding=17,
+        num_token_padding=17,
         use_per_token_if_dynamic=use_per_token_if_dynamic)
 
     per_tensor_weights = (weight_scale.numel() == 1)
@@ -177,8 +177,9 @@ def apply_fp8_linear(
         output, _ = torch._scaled_mm(qinput,
                                      weight,
                                      out_dtype=torch.float32)
-        # Unpad (undo batch_dim_padding)
+        # Unpad (undo num_token_padding)
         output = torch.narrow(output, 0, 0, input.shape[0])
+        x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
 
         # DQ
         # C = sw * sx * (X * W) + bias
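The added `torch.narrow` on `x_scale` is the unpad counterpart of the padded per-token scale above. A standalone sketch of that step on CPU tensors with hypothetical shapes, showing that after narrowing, the per-token scale broadcasts row-wise in the `C = sw * sx * (X * W) + bias` dequantization noted in the comments:

import torch

num_tokens, hidden = 5, 16
output = torch.randn(17, hidden)   # padded fp32 GEMM result
x_scale = torch.rand(17, 1)        # padded per-token scales

output = torch.narrow(output, 0, 0, num_tokens)
x_scale = torch.narrow(x_scale, 0, 0, num_tokens)
assert (x_scale * output).shape == (num_tokens, hidden)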