[ Kernel ] Enable Dynamic Per Token fp8 (#6547)
parent 07eb6f19f3
commit 4cc24f01b1
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.769
+  - name: "exact_match,flexible-extract"
+    value: 0.769
+limit: 1000
+num_fewshot: 5
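The config above is the expected-score baseline for the lm-eval-harness smoke test; the commented bash line reproduces the run. As a minimal sketch of how such a config could be consumed, load the YAML and compare each measured metric against the recorded value within a tolerance (the function name, tolerance, and comparison logic below are illustrative assumptions, not the harness's actual runner):

# Sketch of a config consumer (assumed logic, not the actual
# .buildkite/lm-eval-harness runner).
import yaml  # PyYAML

def check_metrics(config_path: str, measured: dict, rel_tol: float = 0.05) -> None:
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[task["name"]][metric["name"]]
            assert abs(got - expected) <= rel_tol * expected, (
                f'{task["name"]}/{metric["name"]}: got {got}, expected ~{expected}')

# Hypothetical usage with made-up measured results:
# check_metrics("Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml",
#               {"gsm8k": {"exact_match,strict-match": 0.77,
#                          "exact_match,flexible-extract": 0.77}})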
@@ -3,4 +3,5 @@ Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
@@ -27,7 +27,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
                    device="cuda") + 1e-6  # avoid nans
 
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
-    ops_out, ops_scales = ops.dynamic_per_token_scaled_fp8_quant(x)
+    ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               use_per_token_if_dynamic=True)
 
     assert torch.allclose(ref_scales, ops_scales)
     assert torch.allclose(ref_out.to(dtype=torch.float32),
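ref_dynamic_per_token_quant itself is defined outside this hunk. As a rough sketch of what a per-token dynamic FP8 reference plausibly computes, assuming one scale per row derived from the row-wise absolute maximum (the real kernel's rounding details may differ; requires a PyTorch build with float8 dtypes):

import torch

def ref_per_token_quant_sketch(x: torch.Tensor,
                               quant_dtype: torch.dtype = torch.float8_e4m3fn):
    # One scale per token (row): row-wise absolute maximum over the fp8 max.
    # The test adds 1e-6 to x, so no row is all zeros and no scale is zero.
    fp8_max = torch.finfo(quant_dtype).max
    scales = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / fp8_max
    # Scale into the fp8 range, clamp, and cast.
    x_q = (x.to(torch.float32) / scales).clamp(-fp8_max, fp8_max).to(quant_dtype)
    return x_q, scales  # scales: (num_tokens, 1), float32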
@@ -300,6 +300,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantize input tensor to FP8 and return quantized tensor and scale.
@@ -315,6 +316,8 @@ def scaled_fp8_quant(
         scale: Optional scaling factor for the FP8 quantization
         batch_dim_padding: If specified, pad the first dimension
             of the output to at least this value.
+        use_per_token_if_dynamic: Whether to do per_tensor or per_token
+            in the dynamic quantization case.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
@@ -328,24 +331,21 @@ def scaled_fp8_quant(
     else:
         output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     if scale is None:
-        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
-        torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        if use_per_token_if_dynamic:
+            scale = torch.empty((input.numel() // input.shape[-1], 1),
+                                device=input.device,
+                                dtype=torch.float32)
+            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                output, input, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
     else:
         torch.ops._C.static_scaled_fp8_quant(output, input, scale)
 
     return output, scale
 
 
-def dynamic_per_token_scaled_fp8_quant(
-        input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    scales = torch.empty((input.numel() // input.shape[-1], 1),
-                         device=input.device,
-                         dtype=torch.float32)
-    torch.ops._C.dynamic_per_token_scaled_fp8_quant(output, input, scales)
-    return output, scales
-
-
 # int8
 def scaled_int8_quant(
     input: torch.Tensor,
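A short usage sketch of the extended entry point, assuming a CUDA build of vLLM with the FP8 quant kernels compiled: with dynamic quantization (scale=None), the default path returns a single per-tensor scale, while use_per_token_if_dynamic=True returns one scale per token with the (num_tokens, 1) shape allocated above.

import torch
from vllm import _custom_ops as ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-tensor quantization: one scale for the whole tensor.
q_pt, scale_pt = ops.scaled_fp8_quant(x)        # scale_pt.shape == (1,)

# Dynamic per-token quantization: one scale per row of the input.
q_tok, scale_tok = ops.scaled_fp8_quant(
    x, use_per_token_if_dynamic=True)           # scale_tok.shape == (16, 1)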
@@ -103,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=True)
@@ -214,7 +214,8 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
-            cutlass_fp8_supported=self.cutlass_fp8_supported)
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False)
 
 
 class Fp8MoEMethod(FusedMoEMethodBase):
@@ -107,31 +107,43 @@ def apply_fp8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     cutlass_fp8_supported: bool = True,
+    use_per_token_if_dynamic: bool = False,
 ) -> torch.Tensor:
     # ops.scaled_fp8_quant supports both dynamic and static quant.
     #   If dynamic, layer.input_scale is None and x_scale computed from x.
     #   If static, layer.input_scale is scalar and x_scale is input_scale.
 
+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
     if cutlass_fp8_supported:
-        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
 
         # Fused GEMM_DQ
-        output = ops.cutlass_scaled_mm(qinput,
-                                       weight,
-                                       out_dtype=input.dtype,
-                                       scale_a=x_scale,
-                                       scale_b=weight_scale,
-                                       bias=bias)
+        return ops.cutlass_scaled_mm(qinput,
+                                     weight,
+                                     out_dtype=input.dtype,
+                                     scale_a=x_scale,
+                                     scale_b=weight_scale,
+                                     bias=bias)
 
+    # torch.scaled_mm supports per tensor weights + activations only
+    # so fallback to naive if per channel or per token
     else:
         # Note: we pad the input because torch._scaled_mm is more performant
         # for matrices with batch dimension > 16.
         # This could change in the future.
-        qinput, x_scale = ops.scaled_fp8_quant(input,
-                                               input_scale,
-                                               batch_dim_padding=17)
+        qinput, x_scale = ops.scaled_fp8_quant(
+            input,
+            input_scale,
+            batch_dim_padding=17,
+            use_per_token_if_dynamic=use_per_token_if_dynamic)
 
-        if weight_scale.numel() == 1:
+        per_tensor_weights = (weight_scale.numel() == 1)
+        per_tensor_activations = (x_scale.numel() == 1)
+
+        if per_tensor_weights and per_tensor_activations:
             # Fused GEMM_DQ
             output, _ = torch._scaled_mm(qinput,
                                          weight,
@@ -139,9 +151,11 @@ def apply_fp8_linear(
                                          scale_a=x_scale,
                                          scale_b=weight_scale,
                                          bias=bias)
+            return torch.narrow(output, 0, 0, input.shape[0])
+
         else:
-            # Fallback for channelwise case, where the weight scales are
-            # applied separately.
+            # Fallback for channelwise case, where we use unfused DQ
+            # due to limitations with scaled_mm
 
             # Symmetric quantized GEMM by definition computes the following:
             #   C = (s_x * X) (s_w * W) + bias
@@ -155,21 +169,21 @@ def apply_fp8_linear(
             # For the scaled_mm fallback case, we break this down, since it
             # does not support s_w being a vector.
 
-            # This computes C = sx * (X * W).
+            # GEMM
+            # This computes C = (X * W).
             # Output in fp32 to allow subsequent ops to happen in-place
             output, _ = torch._scaled_mm(qinput,
                                          weight,
-                                         out_dtype=torch.float32,
-                                         scale_a=x_scale)
+                                         out_dtype=torch.float32)
+            # Unpad (undo batch_dim_padding)
+            output = torch.narrow(output, 0, 0, input.shape[0])
 
-            # C = sw * sx * (X * W)
-            output = output * weight_scale.t()
+            # DQ
+            # C = sw * sx * (X * W) + bias
+            output = output * x_scale * weight_scale.t()
             if bias is not None:
-                # C = sw * sx * (X * W) + bias
                 output = output + bias
-            output = output.to(dtype=input.dtype)
-
-    return torch.narrow(output, 0, 0, input.shape[0])
+            return output.to(dtype=input.dtype)
 
 
 def apply_int8_linear(
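The unfused dequantization in the fallback path above relies on broadcasting: x_scale is a per-token column vector and weight_scale.t() a per-output-channel row vector (weight_scale being shaped (out_features, 1) is an assumption consistent with the .t() call), so their product rescales every element of the fp32 GEMM output. A minimal shape-level sketch with illustrative sizes:

import torch

num_tokens, out_features = 4, 6

# fp32 GEMM result C = X @ W, prior to dequantization (random stand-in here).
output = torch.randn(num_tokens, out_features, dtype=torch.float32)

x_scale = torch.rand(num_tokens, 1)          # per-token activation scales
weight_scale = torch.rand(out_features, 1)   # per-channel weight scales

# C = s_x * s_w * (X @ W): (num_tokens, 1) and (1, out_features) broadcast
# across the (num_tokens, out_features) GEMM output.
dq = output * x_scale * weight_scale.t()
assert dq.shape == (num_tokens, out_features)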