From ca77dd7a44f2bc103c668560818918ac0335835a Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Thu, 10 Oct 2024 00:28:08 +0800 Subject: [PATCH] [Hardware][CPU] Support AWQ for CPU backend (#7515) --- .buildkite/run-cpu-test.sh | 10 +- Dockerfile.cpu | 2 +- .../quantization/supported_hardware.rst | 4 +- tests/quantization/test_ipex_quant.py | 28 +++ vllm/model_executor/layers/linear.py | 2 +- .../layers/quantization/__init__.py | 2 + .../layers/quantization/awq_marlin.py | 4 + .../layers/quantization/ipex_quant.py | 166 ++++++++++++++++++ vllm/worker/cpu_worker.py | 3 +- 9 files changed, 214 insertions(+), 7 deletions(-) create mode 100644 tests/quantization/test_ipex_quant.py create mode 100644 vllm/model_executor/layers/quantization/ipex_quant.py diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index c1c471ec..62d3afb0 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -27,13 +27,19 @@ docker exec cpu-test bash -c " pytest -v -s tests/models/decoder_only/language \ --ignore=tests/models/test_fp8.py \ --ignore=tests/models/decoder_only/language/test_jamba.py \ + --ignore=tests/models/decoder_only/language/test_granitemoe.py \ --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # Run compressed-tensor test +# docker exec cpu-test bash -c " +# pytest -s -v \ +# tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ +# tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token" + +# Run AWQ test docker exec cpu-test bash -c " pytest -s -v \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ - tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token" + tests/quantization/test_ipex_quant.py" # online inference docker exec cpu-test bash -c " diff --git a/Dockerfile.cpu b/Dockerfile.cpu index a9d97a3e..1803b386 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl +RUN pip install intel_extension_for_pytorch==2.4.0 WORKDIR /workspace diff --git a/docs/source/quantization/supported_hardware.rst b/docs/source/quantization/supported_hardware.rst index ea587e05..9bf0cdb8 100644 --- a/docs/source/quantization/supported_hardware.rst +++ b/docs/source/quantization/supported_hardware.rst @@ -28,7 +28,7 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✗ - ✗ - - ✗ + - ✅︎ - ✗ - ✗ * - GPTQ @@ -61,7 +61,7 @@ The table below shows the compatibility of various quantization implementations - ✅︎ - ✗ - ✗ - - ✗ + - ✅︎ - ✗ - ✗ * - FP8 (W8A8) diff --git a/tests/quantization/test_ipex_quant.py b/tests/quantization/test_ipex_quant.py new file mode 100644 index 00000000..d541efce --- /dev/null +++ b/tests/quantization/test_ipex_quant.py @@ -0,0 +1,28 @@ +"""Test model set-up and inference for quantized HF models supported + on the CPU backend using IPEX (including AWQ). + + Validating the configuration and printing results for manual checking. + + Run `pytest tests/quantization/test_ipex_quant.py`. 
+""" + +import pytest + +from vllm.platforms import current_platform + +MODELS = [ + "casperhansen/llama-3-8b-instruct-awq", +] +DTYPE = ["bfloat16"] + + +@pytest.mark.skipif(not current_platform.is_cpu(), + reason="only supports the CPU backend.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", DTYPE) +def test_ipex_quant(vllm_runner, model, dtype): + with vllm_runner(model, dtype=dtype) as llm: + output = llm.generate_greedy(["The capital of France is"], + max_tokens=32) + assert output + print(output) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index a3d1dc2c..94f30412 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -27,7 +27,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", - "ModelOptFp8LinearMethod" + "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod" ] diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3c38f0a0..da841d05 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm.model_executor.layers.quantization.neuron_quant import ( @@ -49,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, "neuron_quant": NeuronQuantConfig, + "ipex": IPEXConfig, } diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 294fe118..b3d93b28 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) +from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -123,6 +124,9 @@ class AWQMarlinConfig(QuantizationConfig): group_size = quant_config.get("group_size") has_zp = quant_config.get("zero_point") + if not current_platform.is_cuda(): + return False + if quant_method != "awq": return False diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py new file mode 100644 index 00000000..e5405263 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -0,0 +1,166 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.platforms import current_platform + + +class 
IPEXConfig(QuantizationConfig):
+    """INT4 weight-only quantization config class using IPEX for the CPU
+    backend, including AWQ.
+    """
+
+    IPEX_QUANT_METHOD_MAP = {
+        "awq": 1,
+        "gptq": 2,
+    }
+
+    def __init__(
+        self,
+        method: str,
+        weight_bits: int,
+        group_size: int,
+    ) -> None:
+        self.method = method
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.pack_factor = 32 // self.weight_bits
+
+        if self.weight_bits not in [4]:
+            raise ValueError(f"IPEX quantization supports weight bits [4], "
+                             f"but got {self.weight_bits}.")
+
+        if self.method == "awq":
+            self.quant_method = IPEXAWQLinearMethod
+        else:
+            raise ValueError(f"IPEX quantization supports [awq], "
+                             f"but got {self.method}.")
+
+    def __repr__(self) -> str:
+        return (f"IPEXConfig(method={self.method}, "
+                f"weight_bits={self.weight_bits}, "
+                f"group_size={self.group_size})")
+
+    def get_ipex_quant_method_id(self) -> int:
+        return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "ipex"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return -1
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
+        method = cls.get_from_keys(config, ["quant_method"]).lower()
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        return cls(method, weight_bits, group_size)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg,
+                                     user_quant) -> Optional[str]:
+        if not current_platform.is_cpu():
+            return None
+
+        quant_method = hf_quant_cfg.get("quant_method", "").lower()
+
+        if quant_method in ["awq"]:
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["LinearMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return self.quant_method(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        if self.method == "awq":
+            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
+        else:
+            return []
+
+
+class IPEXAWQLinearMethod(AWQLinearMethod):
+    """AWQ linear method using IPEX for the CPU backend.
+    """
+
+    def __init__(self, quant_config: IPEXConfig):
+        self.quant_config = quant_config  # type: ignore
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer=layer)
+
+        bias = layer.bias if not layer.skip_bias_add else None
+
+        try:
+            import intel_extension_for_pytorch as ipex
+            if ipex.__version__ < "2.4.0":
+                raise ImportError("intel_extension_for_pytorch version is "
+                                  "too old. Please install "
+                                  "intel_extension_for_pytorch>=2.4.0.")
+        except ImportError as err:
+            raise ImportError(
+                "Please install "
+                "intel_extension_for_pytorch>=2.4.0 via "
+                "`pip install intel_extension_for_pytorch>=2.4.0`"
+                " to use IPEX-AWQ linear method.") from err
+
+        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
+        # with better performance.
+        lowp_mode = ipex.quantization.WoqLowpMode.INT8
+        # The weight will be de-packed from INT4 to INT8.
+        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
+        # The float activation will be quantized (dynamic, per-token) to INT8.
+ act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + + layer.ipex_output_size = layer.qweight.size( + 1) * self.quant_config.pack_factor + layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\ + WeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method= + self.quant_config.get_ipex_quant_method_id() # type: ignore + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 7384ffcb..d6e3670e 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -215,7 +215,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): def init_device(self) -> None: if self.local_omp_cpuid != "all": ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) - logger.info(ret) + if ret: + logger.info(ret) self.init_distributed_environment() # Set random seed.
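Usage note (not part of the patch): a minimal sketch of how the new CPU AWQ path can be exercised end to end, assuming a CPU build of vLLM with intel_extension_for_pytorch>=2.4.0 installed. It reuses the AWQ checkpoint from tests/quantization/test_ipex_quant.py; on the CPU backend, IPEXConfig.override_quantization_method() reroutes the AWQ checkpoint to the IPEX-AWQ linear method automatically, so no explicit `quantization` argument is needed.

    # Hypothetical usage sketch, not part of this patch.
    from vllm import LLM, SamplingParams

    # On a CPU build, the AWQ config of this checkpoint is overridden to the
    # "ipex" quantization method (see IPEXConfig.override_quantization_method).
    llm = LLM(model="casperhansen/llama-3-8b-instruct-awq", dtype="bfloat16")

    # Greedy decoding, mirroring the generate_greedy() call in the new test.
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["The capital of France is"], params)
    print(outputs[0].outputs[0].text)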