[Hardware][CPU] Support AWQ for CPU backend (#7515)

parent 7dea289066
commit ca77dd7a44
@@ -27,13 +27,19 @@ docker exec cpu-test bash -c "
   pytest -v -s tests/models/decoder_only/language \
     --ignore=tests/models/test_fp8.py \
     --ignore=tests/models/decoder_only/language/test_jamba.py \
+    --ignore=tests/models/decoder_only/language/test_granitemoe.py \
     --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
 
 # Run compressed-tensor test
+# docker exec cpu-test bash -c "
+#   pytest -s -v \
+#   tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
+#   tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"
+
+# Run AWQ test
 docker exec cpu-test bash -c "
   pytest -s -v \
-  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
-  tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynanmic_per_token"
+  tests/quantization/test_ipex_quant.py"
 
 # online inference
 docker exec cpu-test bash -c "
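For reference, the new AWQ step can also be run outside the CI container. A minimal sketch, assuming a CPU build of vLLM with intel_extension_for_pytorch installed (this snippet is not part of the CI script itself):

    # Sketch: drive the new AWQ/IPEX test the same way the CI step does,
    # but from Python. Assumes a CPU build of vLLM and IPEX are installed.
    import pytest

    exit_code = pytest.main(["-s", "-v", "tests/quantization/test_ipex_quant.py"])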
@@ -22,7 +22,7 @@ ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/li
 
 RUN echo 'ulimit -c 0' >> ~/.bashrc
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
+RUN pip install intel_extension_for_pytorch==2.4.0
 
 WORKDIR /workspace
 
@@ -28,7 +28,7 @@ The table below shows the compatibility of various quantization implementations
     - ✅︎
     - ✗
     - ✗
-    - ✗
+    - ✅︎
     - ✗
     - ✗
   * - GPTQ
@@ -61,7 +61,7 @@ The table below shows the compatibility of various quantization implementations
     - ✅︎
     - ✗
     - ✗
-    - ✗
+    - ✅︎
     - ✗
     - ✗
   * - FP8 (W8A8)
tests/quantization/test_ipex_quant.py (new file, 28 lines):
"""Test model set-up and inference for quantized HF models supported
 on the CPU backend using IPEX (including AWQ).

 Validating the configuration and printing results for manual checking.

 Run `pytest tests/quantization/test_ipex_quant.py`.
"""

import pytest

from vllm.platforms import current_platform

MODELS = [
    "casperhansen/llama-3-8b-instruct-awq",
]
DTYPE = ["bfloat16"]


@pytest.mark.skipif(not current_platform.is_cpu(),
                    reason="only supports the CPU backend.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
    with vllm_runner(model, dtype=dtype) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)
        assert output
        print(output)
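The test drives the model through the vllm_runner fixture; the same check can be expressed with the public LLM API. A minimal sketch, assuming a CPU build of vLLM (the AWQ checkpoint is routed to the new IPEX method by the override logic added below):

    # Sketch of the equivalent offline check with the public API. On the CPU
    # backend the AWQ checkpoint is auto-detected and handled by the new
    # "ipex" quantization method.
    from vllm import LLM, SamplingParams

    llm = LLM(model="casperhansen/llama-3-8b-instruct-awq", dtype="bfloat16")
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)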
@@ -27,7 +27,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
     "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
-    "ModelOptFp8LinearMethod"
+    "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
 ]
 
 
@@ -21,6 +21,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
+from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
 from vllm.model_executor.layers.quantization.neuron_quant import (
@@ -49,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "neuron_quant": NeuronQuantConfig,
+    "ipex": IPEXConfig,
 }
 
 
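With the "ipex" entry registered, the config can be resolved from the registry and built from an AWQ-style quantization dict just like the other methods. A minimal sketch (field values are illustrative, not taken from a specific checkpoint):

    # Sketch: look up the new registry entry and build an IPEXConfig from an
    # AWQ-style quantization dict (values below are illustrative).
    from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

    ipex_config_cls = QUANTIZATION_METHODS["ipex"]
    quant_config = ipex_config_cls.from_config({
        "quant_method": "awq",
        "w_bit": 4,
        "q_group_size": 128,
    })
    print(quant_config)  # an IPEXConfig for 4-bit AWQ, group size 128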
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -123,6 +124,9 @@ class AWQMarlinConfig(QuantizationConfig):
         group_size = quant_config.get("group_size")
         has_zp = quant_config.get("zero_point")
 
+        if not current_platform.is_cuda():
+            return False
+
         if quant_method != "awq":
             return False
 
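The added current_platform.is_cuda() check makes the AWQ-Marlin compatibility test bail out on non-CUDA platforms, so on CPU an AWQ checkpoint falls through to the IPEX method chosen by IPEXConfig.override_quantization_method. A simplified sketch of that selection pattern (pick_awq_backend is illustrative, not the actual vLLM dispatch code):

    # Simplified, illustrative sketch of the platform gating this change
    # enables; the real selection happens in vLLM's quantization config code.
    from vllm.platforms import current_platform


    def pick_awq_backend(hf_quant_cfg: dict) -> str:
        if hf_quant_cfg.get("quant_method", "").lower() != "awq":
            return "unquantized"
        if current_platform.is_cuda():
            return "awq_marlin"   # Marlin kernels remain CUDA-only
        if current_platform.is_cpu():
            return "ipex"         # new CPU path via IPEX
        return "awq"              # plain AWQ fallback elsewhere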
vllm/model_executor/layers/quantization/ipex_quant.py (new file, 166 lines):
from typing import Any, Dict, List, Optional

import torch

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.platforms import current_platform


class IPEXConfig(QuantizationConfig):
    """INT8 quantization config class using IPEX for the CPU backend,
    including AWQ.
    """

    IPEX_QUANT_METHOD_MAP = {
        "awq": 1,
        "gptq": 2,
    }

    def __init__(
        self,
        method: str,
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.method = method
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.pack_factor = 32 // self.weight_bits

        if self.weight_bits not in [4]:
            raise ValueError(f"IPEX quantization supports weight bits [4], "
                             f"but got {self.weight_bits}.")

        if self.method == "awq":
            self.quant_method = IPEXAWQLinearMethod
        else:
            raise ValueError(f"IPEX quantization supports [awq], "
                             f"but got {self.method}.")

    def __repr__(self) -> str:
        return (f"IPEXConfig(method={self.method}"
                f"weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}")

    def get_ipex_quant_method_id(self) -> int:
        return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]

    @classmethod
    def get_name(cls) -> str:
        return "ipex"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
        method = cls.get_from_keys(config, ["quant_method"]).lower()
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        return cls(method, weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        if not current_platform.is_cpu():
            return None

        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        if quant_method in ["awq"]:
            return cls.get_name()

        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["LinearMethodBase"]:
        if isinstance(layer, LinearBase):
            return self.quant_method(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        if self.method == "awq":
            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
        else:
            return []


class IPEXAWQLinearMethod(AWQLinearMethod):
    """AWQ linear method using IPEX for the CPU backend.
    """

    def __init__(self, quant_config: IPEXConfig):
        self.quant_config = quant_config  # type: ignore

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer=layer)

        bias = layer.bias if not layer.skip_bias_add else None

        try:
            import intel_extension_for_pytorch as ipex
            if ipex.__version__ < "2.4.0":
                raise ImportError("intel_extension_for_pytorch version is "
                                  "wrong. Please install "
                                  "intel_extension_for_pytorch>=2.4.0.")
        except ImportError as err:
            raise ImportError(
                "Please install "
                "intel_extension_for_pytorch>=2.4.0 via "
                "`pip install intel_extension_for_pytorch>=2.4.0`"
                " to use IPEX-AWQ linear method.") from err

        # Using the compute dtype (lowp_mode) as INT8 to leverage instructions
        # with better performance.
        lowp_mode = ipex.quantization.WoqLowpMode.INT8
        # The weight will be de-packed from INT4 to INT8.
        weight_dtype = ipex.quantization.WoqWeightDtype.INT4
        # The float activation will be quantized (dynamic, per-token) to INT8.
        act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH

        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
            weight_dtype=weight_dtype,
            lowp_mode=lowp_mode,
            act_quant_mode=act_quant_mode,
            group_size=self.quant_config.group_size,
        )

        layer.ipex_output_size = layer.qweight.size(
            1) * self.quant_config.pack_factor
        layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
            WeightOnlyQuantizedLinear.from_weight(
                layer.qweight,
                layer.scales,
                layer.qzeros,
                layer.qweight.size(0),
                layer.ipex_output_size,
                qconfig=qconfig,
                bias=bias,
                group_size=self.quant_config.group_size,
                quant_method=
                self.quant_config.get_ipex_quant_method_id()  # type: ignore
            )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = layer.ipex_qlinear(reshaped_x)

        return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
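IPEXAWQLinearMethod keeps the AWQ packing convention: eight 4-bit weights per int32 column, so the logical output size is recovered by multiplying the packed width by pack_factor, and apply() flattens the leading dimensions before calling the IPEX qlinear. A minimal sketch of that shape bookkeeping (tensor sizes are illustrative):

    # Sketch of the shape arithmetic used above; layer sizes are illustrative.
    import torch

    weight_bits = 4
    pack_factor = 32 // weight_bits           # 8 packed nibbles per int32
    input_size, output_size = 4096, 11008     # example layer shape

    qweight = torch.empty(input_size, output_size // pack_factor,
                          dtype=torch.int32)  # AWQ-packed weight
    ipex_output_size = qweight.size(1) * pack_factor
    assert ipex_output_size == output_size

    x = torch.randn(2, 16, input_size, dtype=torch.bfloat16)
    reshaped_x = x.reshape(-1, x.shape[-1])            # (32, 4096) fed to qlinear
    out_shape = x.shape[:-1] + (ipex_output_size, )    # (2, 16, 11008) restored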
@@ -215,7 +215,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
-            logger.info(ret)
+            if ret:
+                logger.info(ret)
 
         self.init_distributed_environment()
         # Set random seed.