[Misc] Add ignored layers for fp8 quantization (#6657)
This commit is contained in:
parent
38c4b7e863
commit
0eb0757bef
@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from torch.nn import Module
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
FUSED_LAYER_NAME_MAPPING)
|
||||
|
||||
|
||||
class CompressionFormat(Enum):
|
||||
dense = "dense"
|
||||
@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
|
||||
# fused_name: List[shard_name]
|
||||
_FUSED_LAYER_NAME_MAPPING = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
|
||||
def should_ignore_layer(layer_name: Optional[str],
|
||||
ignore: Iterable[str]) -> bool:
|
||||
if layer_name is None:
|
||||
@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in _FUSED_LAYER_NAME_MAPPING:
|
||||
shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
|
||||
if proj_name in FUSED_LAYER_NAME_MAPPING:
|
||||
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
|
||||
@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear, create_per_channel_scale_param)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
@ -18,14 +20,6 @@ from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Note: this is a hack. We should update each model to register the
|
||||
# stacked params and get it from there instead in a future PR.
|
||||
# fused_name: List[shard_name]
|
||||
_FUSED_LAYER_NAME_MAPPING = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
|
||||
class FBGEMMFp8Config(QuantizationConfig):
|
||||
"""Config class for FBGEMM Fp8."""
|
||||
@ -62,37 +56,10 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
|
||||
return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
|
||||
|
||||
def _is_layer_skipped(self, prefix: str) -> bool:
|
||||
# prefix: model.layers.0.self_attn.q_proj
|
||||
# proj_name: q_proj
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in _FUSED_LAYER_NAME_MAPPING:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = shard_prefix in self.ignore_list
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
else:
|
||||
is_skipped = prefix in self.ignore_list
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
if self._is_layer_skipped(prefix):
|
||||
if is_layer_skipped(prefix, self.ignore_list):
|
||||
return UnquantizedLinearMethod()
|
||||
return FBGEMMFp8LinearMethod(self)
|
||||
return None
|
||||
|
||||
@ -8,12 +8,15 @@ from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
fused_moe)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
|
||||
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
|
||||
@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
if is_checkpoint_fp8_serialized:
|
||||
@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
|
||||
raise ValueError(
|
||||
f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
self.ignored_layers = ignored_layers or []
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme)
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
if is_layer_skipped(prefix, self.ignored_layers):
|
||||
return UnquantizedLinearMethod()
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self)
|
||||
|
||||
@ -1,10 +1,48 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
SUPPORTED_NUM_BITS = [4, 8]
|
||||
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
||||
|
||||
# Note: this is a hack. We should update each model to register the
|
||||
# stacked params and get it from there instead in a future PR.
|
||||
# fused_name: List[shard_name]
|
||||
FUSED_LAYER_NAME_MAPPING = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"]
|
||||
}
|
||||
|
||||
|
||||
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
|
||||
# prefix: model.layers.0.self_attn.q_proj
|
||||
# proj_name: q_proj
|
||||
proj_name = prefix.split(".")[-1]
|
||||
if proj_name in FUSED_LAYER_NAME_MAPPING:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
|
||||
]
|
||||
|
||||
is_skipped = None
|
||||
for shard_prefix in shard_prefixes:
|
||||
is_shard_skipped = shard_prefix in ignored_layers
|
||||
|
||||
if is_skipped is None:
|
||||
is_skipped = is_shard_skipped
|
||||
elif is_shard_skipped != is_skipped:
|
||||
raise ValueError(
|
||||
f"Detected some but not all shards of {prefix} "
|
||||
"are quantized. All shards of fused layers "
|
||||
"to have the same precision.")
|
||||
else:
|
||||
is_skipped = prefix in ignored_layers
|
||||
|
||||
assert is_skipped is not None
|
||||
return is_skipped
|
||||
|
||||
|
||||
def get_pack_factor(num_bits):
|
||||
assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user