[Kernel] AQ AZP 4/4: Integrate asymmetric quantization to linear method (#7271)
This commit is contained in: parent a9b15c606f, commit 172d1cd276
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1
+model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.764
+  - name: "exact_match,flexible-extract"
+    value: 0.764
+limit: 250
+num_fewshot: 5
@@ -1,6 +1,7 @@
 Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml
 Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
 Minitron-4B-Base-FP8.yaml
@@ -49,10 +49,15 @@ def test_lm_eval_correctness():
     results = launch_lm_eval(eval_config)

     # Confirm scores match ground truth.
+    success = True
     for task in eval_config["tasks"]:
         for metric in task["metrics"]:
             ground_truth = metric["value"]
             measured_value = results["results"][task["name"]][metric["name"]]
             print(f'{task["name"]} | {metric["name"]}: '
                   f'ground_truth={ground_truth} | measured={measured_value}')
-            assert numpy.isclose(ground_truth, measured_value, rtol=RTOL)
+            success = success and numpy.isclose(
+                ground_truth, measured_value, rtol=RTOL)
+
+    # Assert at the end, print all scores even on failure for debugging.
+    assert success
@@ -2,6 +2,7 @@

 Run `pytest tests/quantization/test_compressed_tensors.py`.
 """
+from typing import Optional

 import pytest
 import torch
@@ -14,14 +15,16 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationType)


-@pytest.mark.parametrize("model_args", [
-    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
-     QuantizationType.INT, 2560),
-    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
-     QuantizationType.INT, 2560),
-])
+@pytest.mark.parametrize(
+    "model_args",
+    [("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor",
+      QuantizationType.INT, 2560, True),
+     ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel",
+      QuantizationType.INT, 2560, True),
+     ("nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama", "tensor",
+      QuantizationType.INT, 2560, False)])
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
-    model_path, strategy, quant_type, shape_0 = model_args
+    model_path, strategy, quant_type, shape_0, is_symmetric = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -31,6 +34,18 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
         gate_up_proj = layer.mlp.gate_up_proj
         down_proj = layer.mlp.down_proj

+        # assert zp for symmetric and asymmetric cases
+        def zp_valid(zp: Optional[torch.Tensor]):
+            if is_symmetric:
+                return zp is None
+
+            return zp is not None and zp.dtype is torch.int32
+
+        assert zp_valid(qkv_proj.input_zero_point)
+        assert zp_valid(o_proj.input_zero_point)
+        assert zp_valid(gate_up_proj.input_zero_point)
+        assert zp_valid(down_proj.input_zero_point)
+
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(gate_up_proj.quant_method,
@@ -69,9 +84,12 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):

 @pytest.mark.parametrize("model_args", [
     ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
     ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
+    ("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
+     "channel"),
 ])
-def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
+def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
     model_path, strategy = model_args
     with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
@@ -160,4 +178,4 @@ def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
         output = llm.generate_greedy("Hello world!", max_tokens=20)
-        assert output
+        assert output
@@ -138,10 +138,11 @@ class CompressedTensorsConfig(QuantizationConfig):
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_tensor = (weight_strategy and input_quant.strategy
                      == QuantizationStrategy.TENSOR.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic

-        return is_8_bits and is_tensor and is_symmetric and is_static
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_8_bits and is_tensor and weight_quant.symmetric and is_static

     def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
@@ -151,10 +152,11 @@ class CompressedTensorsConfig(QuantizationConfig):
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
         is_token = (weight_strategy and input_quant.strategy
                     == QuantizationStrategy.TOKEN.value)
-        is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_dynamic = not weight_quant.dynamic and input_quant.dynamic

-        return is_8_bits and is_token and is_symmetric and is_dynamic
+        # Both symmetric and asymmetric input quantization supported.
+        # Only symmetric weight quantization supported.
+        return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

     def _is_fp8_w8a8(self, weight_quant: BaseModel,
                      input_quant: BaseModel) -> bool:
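Reviewer note: with the relaxed predicates above, an int8 W8A8 scheme is accepted whenever the weights are quantized symmetrically, while the activations may carry a zero point. Below is a rough sketch of the dynamic-per-token decision using hypothetical stand-in objects; the real code receives compressed-tensors BaseModel configs and QuantizationStrategy enum values, not these dataclasses.

from dataclasses import dataclass

@dataclass
class FakeQuantArgs:
    # Hypothetical stand-in for compressed-tensors quantization args.
    num_bits: int
    strategy: str
    symmetric: bool
    dynamic: bool

def is_dynamic_token_w8a8(weight_quant: FakeQuantArgs,
                          input_quant: FakeQuantArgs) -> bool:
    # Mirrors the updated predicate: weights must be symmetric,
    # activations may be symmetric or asymmetric.
    is_8_bits = weight_quant.num_bits == 8 and input_quant.num_bits == 8
    weight_strategy = weight_quant.strategy in ("tensor", "channel")
    is_token = weight_strategy and input_quant.strategy == "token"
    is_dynamic = not weight_quant.dynamic and input_quant.dynamic
    return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

weights = FakeQuantArgs(8, "channel", symmetric=True, dynamic=False)
acts_sym = FakeQuantArgs(8, "token", symmetric=True, dynamic=True)
acts_asym = FakeQuantArgs(8, "token", symmetric=False, dynamic=True)

assert is_dynamic_token_w8a8(weights, acts_sym)   # accepted before and after
assert is_dynamic_token_w8a8(weights, acts_asym)  # newly accepted: asymmetric activations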
@@ -265,12 +267,14 @@ class CompressedTensorsConfig(QuantizationConfig):
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Int8(
                     strategy=weight_quant.strategy,
-                    is_static_input_scheme=True)
+                    is_static_input_scheme=True,
+                    input_symmetric=input_quant.symmetric)

             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Int8(
                     strategy=weight_quant.strategy,
-                    is_static_input_scheme=False)
+                    is_static_input_scheme=False,
+                    input_symmetric=input_quant.symmetric)

         raise NotImplementedError(
             "No compressed-tensors compatible scheme was found.")
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional
 import torch
 from torch.nn import Parameter

+from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
@@ -14,12 +15,16 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
                                            ModelWeightParameter,
                                            PerTensorScaleParameter)

+logger = init_logger(__name__)
+

 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

-    def __init__(self, strategy: str, is_static_input_scheme: bool):
+    def __init__(self, strategy: str, is_static_input_scheme: bool,
+                 input_symmetric: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric

     @classmethod
     def get_min_capability(cls) -> int:
@@ -46,10 +51,43 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                                            requires_grad=False)
         # INPUT SCALE
         if self.is_static_input_scheme:
-            layer.input_scale = Parameter(layer.input_scale.max(),
-                                          requires_grad=False)
+            if self.input_symmetric:
+                layer.input_scale = Parameter(layer.input_scale.max(),
+                                              requires_grad=False)
+                layer.input_zero_point = None
+            else:
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                azps = layer.input_zero_point.to(dtype=torch.int32)
+                range_max = (layer.input_scale *
+                             (int8_traits.max - azps)).max()
+                range_min = (layer.input_scale *
+                             (int8_traits.min - azps)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max -
+                                                   int8_traits.min)
+                layer.input_scale = Parameter(scale, requires_grad=False)
+
+                # AZP loaded as int8 but used as int32
+                azp = (int8_traits.min -
+                       range_min / scale).to(dtype=torch.int32)
+                layer.input_zero_point = Parameter(azp, requires_grad=False)
+
         else:
             layer.input_scale = None
+            layer.input_zero_point = None
+
+        # azp_adj is the AZP adjustment term, used to account for weights.
+        # It does not depend on scales or azp, so it is the same for
+        # static and dynamic quantization.
+        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        if not self.input_symmetric:
+            layer.azp_adj = layer.weight.sum(dim=0,
+                                             keepdim=True,
+                                             dtype=torch.int32)
+        else:
+            layer.azp_adj = None

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
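Reviewer note: for the static asymmetric path above, fused modules can load one (scale, zero_point) pair per partition, and process_weights_after_loading collapses them into a single pair by reconstructing the widest float range they cover. A minimal sketch of that arithmetic with made-up per-partition values (not taken from a real checkpoint):

import torch

# Hypothetical per-partition values, e.g. one pair per fused q/k/v partition.
input_scale = torch.tensor([0.0200, 0.0150, 0.0300])
input_zero_point = torch.tensor([-10, 5, 0], dtype=torch.int8)

int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)

# Widest float range representable by any partition.
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()

# Single fused scale/zero-point covering that range.
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
print(scale.item(), azp.item())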
@@ -90,6 +128,15 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                                            weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)

+            if not self.input_symmetric:
+                # Note: compressed-tensors stores the zp using the same dtype
+                # as the weights
+                # AZP loaded as int8 but used as int32
+                input_zero_point = BasevLLMParameter(
+                    data=torch.empty(1, dtype=torch.int8),
+                    weight_loader=weight_loader)
+                layer.register_parameter("input_zero_point", input_zero_point)
+
     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
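Reviewer note: as the comment says, compressed-tensors serializes the activation zero point in the weights' dtype (int8 here), while the kernels consume it as int32. A toy per-tensor asymmetric quantize/dequantize round-trip under that convention, purely illustrative and not the ops.scaled_int8_quant kernel:

import torch

def toy_asym_quant(x: torch.Tensor):
    # Per-tensor asymmetric int8 quantization; the zero point is kept in
    # int8 for storage and widened to int32 for arithmetic.
    int8_traits = torch.iinfo(torch.int8)
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min) / (int8_traits.max - int8_traits.min)
    azp = (int8_traits.min - x_min / scale).round().to(torch.int8)
    x_q = torch.clamp((x / scale).round() + azp.to(torch.int32),
                      int8_traits.min, int8_traits.max).to(torch.int8)
    return x_q, scale, azp

x = torch.randn(8, 16)
x_q, scale, azp = toy_asym_quant(x)
x_dq = (x_q.to(torch.int32) - azp.to(torch.int32)) * scale
assert (x - x_dq).abs().max() <= scale  # error bounded by roughly one step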
@@ -97,4 +144,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                                 weight=layer.weight,
                                 weight_scale=layer.weight_scale,
                                 input_scale=layer.input_scale,
+                                input_zero_point=layer.input_zero_point,
+                                azp_adj=layer.azp_adj,
                                 bias=bias)
@@ -191,13 +191,28 @@ def apply_int8_linear(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
+    input_zero_point: Optional[torch.Tensor] = None,
+    azp_adj: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ):
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
+    symmetric = azp_adj is None
+    x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
+                                               input_scale,
+                                               input_zero_point,
+                                               symmetric=symmetric)
+
+    if x_zp is not None:
+        return ops.cutlass_scaled_mm_azp(x_q,
+                                         weight,
+                                         scale_a=x_scale,
+                                         scale_b=weight_scale,
+                                         out_dtype=input.dtype,
+                                         azp_adj=azp_adj,
+                                         azp=x_zp,
+                                         bias=bias)
     return ops.cutlass_scaled_mm(x_q,
                                  weight,
                                  scale_a=x_scale,
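Reviewer note: the azp path works because, for a per-tensor zero point, the correction only needs the column sums of the weight: (x_q - azp) @ W == x_q @ W - azp * W.sum(dim=0). That is the azp_adj term computed in process_weights_after_loading and folded into the CUTLASS epilogue. A small numeric sanity check of that identity with toy tensors (int64 here just for a CPU check; the real kernels keep x_q and W in int8 and the adjustment in int32):

import torch

# Toy quantized activations and weights.
x_q = torch.randint(-128, 128, (4, 8))
weight = torch.randint(-128, 128, (8, 16))
azp = 7  # per-tensor activation zero point

# Reference: subtract the zero point before the matmul.
ref = (x_q - azp) @ weight

# Epilogue form: plain matmul, then subtract azp times the column sums.
azp_adj = weight.sum(dim=0, keepdim=True)  # shape (1, 16)
fused = x_q @ weight - azp * azp_adj

assert torch.equal(ref, fused)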