From c2637a613b6140dc16fecd5a1b0f5a9e1d0932ff Mon Sep 17 00:00:00 2001
From: Dipika Sikka
Date: Thu, 13 Jun 2024 10:19:56 -0400
Subject: [PATCH] [Kernel] `w4a16` support for `compressed-tensors` (#5385)

Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
---
 tests/quantization/test_compressed_tensors.py |  27 ++-
 .../compressed_tensors/compressed_tensors.py  |  44 ++++-
 .../compressed_tensors/schemes/__init__.py    |   1 +
 .../schemes/compressed_tensors_w4a16.py       | 168 ++++++++++++++++++
 4 files changed, 230 insertions(+), 10 deletions(-)
 create mode 100644 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py

diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index e6d8218b..5670498f 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -3,12 +3,13 @@
 Run `pytest tests/quantization/test_compressed_tensors.py`.
 """
 
+import pytest
 import torch
 
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -60,3 +61,25 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
         assert qkv_proj.weight.dtype is torch.int8
+
+
+@pytest.mark.parametrize("w4a16_args", [
+    ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
+    ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
+])
+def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
+    model, strategy, group = w4a16_args
+    with vllm_runner(model) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
+
+        assert qkv_proj.scheme.strategy == strategy
+        assert qkv_proj.scheme.group_size == group
+
+        assert qkv_proj.weight_packed.dtype is torch.int32
+        assert qkv_proj.weight_scale.dtype is torch.float16
+        assert qkv_proj.weight_packed.pack_factor == 8
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index d2b0ce0d..c7f04784 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -7,8 +7,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
@@ -47,16 +47,27 @@ class CompressedTensorsConfig(QuantizationConfig):
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
 
+        # The quant_config has multiple config_groups, each containing
+        # an input_activations key with details about how the activations are
+        # quantized, a weights key indicating how the weights are quantized,
+        # and a list of targets under the `targets` key, dictating which
+        # layers are impacted by the quantization details. The quantization
+        # details follow the structure defined by the QuantizationArgs
+        # pydantic model, which is used to verify the structure of the
+        # quant_config and also store the details for later use.
         for key, quant_config in config["config_groups"].items():
             targets = quant_config.get("targets")
             for target in targets:
                 layer_quant_details[target] = {}
                 layer_quant_details[target][
-                    "weight"] = QuantizationArgs.parse_obj(
+                    "weights"] = QuantizationArgs.parse_obj(
                         quant_config.get("weights"))
-                layer_quant_details[target][
-                    "input"] = QuantizationArgs.parse_obj(
-                        quant_config.get("input_activations"))
+                try:
+                    layer_quant_details[target][
+                        "input_activations"] = QuantizationArgs.parse_obj(
+                            quant_config.get("input_activations"))
+                except Exception:
+                    layer_quant_details[target]["input_activations"] = None
 
         return cls(layer_quant_details=layer_quant_details, ignore=ignore)
 
@@ -86,8 +97,23 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
 
+    def _is_w4a16(self, weight_quant: BaseModel,
+                  input_quant: BaseModel) -> bool:
+        input_quant_none = input_quant is None
+        is_4_bits = weight_quant.num_bits == 4
+        is_symmetric = weight_quant.symmetric
+        is_static = not weight_quant.dynamic
+
+        return is_4_bits and input_quant_none and is_symmetric and is_static
+
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
+
+        if self._is_w4a16(weight_quant, input_quant):
+            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
+                                          strategy=weight_quant.strategy,
+                                          group_size=weight_quant.group_size)
+
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8StaticTensor()
 
@@ -113,8 +139,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             raise ValueError(
                 f"Could not find quantization details for {layer}.")
 
-        return self._get_schema(weight_quant=layer_quant_details["weight"],
-                                input_quant=layer_quant_details["input"])
+        return self._get_schema(
+            weight_quant=layer_quant_details["weights"],
+            input_quant=layer_quant_details["input_activations"])
 
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
@@ -140,6 +167,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
             layer=layer,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
+            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index 9a910f06..dc84d000 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -1,6 +1,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
+from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
new file mode 100644
index 00000000..90446a5f
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16.py
@@ -0,0 +1,168 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
+    marlin_permute_scales)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW4A16"]
+
+
+class CompressedTensorsW4A16(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.num_bits = num_bits
+        self.strategy = strategy
+        self.group_size = group_size
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        if self.group_size is not None:
+            group_size = self.group_size
+        else:
+            group_size = input_size
+
+        weight_scale_dim = None
+        scales_and_zp_size = input_size // group_size
+
+        if (input_size != input_size_per_partition
+                and self.group_size is not None):
+            weight_scale_dim = 1
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        weight = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": pack_factor
+            })
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+
+        layer.register_parameter("weight_packed", weight)
+
+        weight_scale = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                scales_and_zp_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(weight_scale, {
+            "input_dim": weight_scale_dim,
+            "output_dim": 0
+        })
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # A 2D array defining the original shape of the weights
+        # before packing
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        layer.input_size = input_size
+        layer.marlin_state = GPTQMarlinState.REPACK
+        layer.is_k_full = True
+        layer.group_size = group_size
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+
+        workspace = torch.zeros(max_workspace_size,
+                                dtype=torch.int,
+                                requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        size_m = reshaped_x.shape[0]
+        part_size_n = layer.output_size_per_partition
+        part_size_k = layer.input_size_per_partition
+
+        out_shape = x.shape[:-1] + (part_size_n, )
+
+        if layer.marlin_state == GPTQMarlinState.REPACK:
+            layer.marlin_state = GPTQMarlinState.READY
+
+            # Newly generated tensors need to replace existing tensors that are
+            # already registered as parameters by vLLM (and won't be freed)
+            def replace_tensor(name, new_t):
+                # It is important to use resize_() here since it ensures
+                # the same buffer is reused
+                getattr(layer, name).resize_(new_t.shape)
+                getattr(layer, name).copy_(new_t)
+                del new_t
+
+            cur_device = layer.weight_packed.device
+
+            # Reset g_idx related tensors
+            layer.g_idx = Parameter(torch.empty(0,
+                                                dtype=torch.int,
+                                                device=cur_device),
+                                    requires_grad=False)
+            layer.g_idx_sort_indices = Parameter(torch.empty(
+                0, dtype=torch.int, device=cur_device),
+                                                 requires_grad=False)
+
+            # Repack weights
+            marlin_qweight = ops.gptq_marlin_repack(
+                layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
+                part_size_k, part_size_n, self.num_bits)
+
+            replace_tensor("weight_packed", marlin_qweight)
+
+            # Permute scales
+            scales_size_k = part_size_k
+            scales_size_n = part_size_n
+
+            marlin_scales = marlin_permute_scales(
+                layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
+                scales_size_n, layer.group_size, self.num_bits)
+            replace_tensor("weight_scale", marlin_scales)
+
+        output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
+                                      layer.weight_scale, layer.g_idx,
+                                      layer.g_idx_sort_indices,
+                                      layer.workspace, self.num_bits, size_m,
+                                      part_size_n, part_size_k,
+                                      layer.is_k_full)
+        return output.reshape(out_shape)
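
Usage note (not part of the patch): the sketch below shows one way to exercise the new w4a16 path end to end, assuming one of the test checkpoints referenced above is reachable and ships a compressed-tensors quantization config so vLLM selects CompressedTensorsW4A16 for its linear layers. It is a minimal illustration, not a test added by this commit.

    from vllm import LLM, SamplingParams

    # Checkpoint name taken from the parametrized test above; assumed to carry
    # compressed-tensors w4a16 metadata so the Marlin-backed scheme is chosen.
    llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-group128-v2")
    sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["Quantizing weights to 4 bits saves memory because"],
                           sampling_params)
    print(outputs[0].outputs[0].text)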