[Misc] Update GPTQ to use vLLMParameters (#7976)

parent dc0b6066ab
commit 2188a60c7e
@@ -4,6 +4,12 @@ gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
 gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
 gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
 gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
+gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
+gptq, TheBloke/Llama-2-7B-GPTQ, main
+gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
+gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
+gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
+gptq, TechxGenus/gemma-1.1-2b-it-GPTQ, main
 compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
 compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
 compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
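Each entry in the test-model list above is a comma-separated triple of quantization backend, model repository, and revision, which the weight-loading test in the next hunk consumes via environment variables. A minimal sketch of how one entry could be split into the keyword arguments the test passes to vllm_runner (the helper name is hypothetical, not part of the commit):

    # Hypothetical helper: turn one "quantization, model, revision" entry into
    # the keyword arguments used by the weight-loading test below.
    def parse_model_entry(line: str) -> dict:
        quantization, model_name, revision = [part.strip() for part in line.split(",")]
        return {
            "model_name": model_name,
            "revision": revision,
            "quantization": quantization,
        }

    print(parse_model_entry(
        "gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True"))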
@@ -1,5 +1,7 @@
 import os

+import torch
+
 MAX_MODEL_LEN = 1024
 MODEL_NAME = os.environ.get("MODEL_NAME",
                             "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
@@ -8,9 +10,12 @@ QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")


 def test_weight_loading(vllm_runner):
+    """
+    Test parameter weight loading with tp>1.
+    """
     with vllm_runner(model_name=MODEL_NAME,
                      revision=REVISION,
-                     dtype="auto",
+                     dtype=torch.half if QUANTIZATION == "gptq" else "auto",
                      quantization=QUANTIZATION,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=2) as model:
@@ -14,8 +14,10 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           PackedColumnParameter,
                                            PackedvLLMParameter,
-                                           PerTensorScaleParameter)
+                                           PerTensorScaleParameter,
+                                           RowvLLMParameter)
 from vllm.model_executor.utils import set_weight_attrs

 logger = init_logger(__name__)
@@ -24,7 +26,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
-    "TPUInt8LinearMethod"
+    "TPUInt8LinearMethod", "GPTQLinearMethod"
 ]

@@ -574,8 +576,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
-            if isinstance(param, PackedvLLMParameter
-                          ) and param.packed_dim == param.output_dim:
+            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
+                                  )) and param.packed_dim == param.output_dim:
                 shard_size, shard_offset = \
                     param.adjust_shard_indexes_for_packing(
                         shard_size=shard_size, shard_offset=shard_offset)
@@ -594,9 +596,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                 shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -724,8 +727,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             # Special case for Quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
-            if isinstance(param, PackedvLLMParameter
-                          ) and param.packed_dim == param.output_dim:
+            if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
+                                  )) and param.packed_dim == param.output_dim:
                 shard_size, shard_offset = \
                     param.adjust_shard_indexes_for_packing(
                         shard_size=shard_size, shard_offset=shard_offset)
@@ -741,12 +744,12 @@ class QKVParallelLinear(ColumnParallelLinear):
                          loaded_shard_id: Optional[str] = None):
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight,
-                                                shard_id=0)
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
+            # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return

@@ -11,7 +11,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedColumnParameter,
+                                           PackedvLLMParameter,
+                                           RowvLLMParameter)


 class GPTQConfig(QuantizationConfig):
@@ -108,6 +112,7 @@ class GPTQLinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs.get("weight_loader")
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -138,73 +143,81 @@ class GPTQLinearMethod(LinearMethodBase):
         scale_and_zero_size = input_size_per_partition // group_size
         scale_and_zero_input_dim = 0

-        qweight = Parameter(
-            torch.empty(
-                input_size_per_partition // self.quant_config.pack_factor,
-                output_size_per_partition,
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 0,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        g_idx = Parameter(
-            torch.tensor(
-                [
-                    i // self.quant_config.group_size
-                    for i in range(input_size_per_partition)
-                ],
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        # Ignore warning from fused linear layers such as QKVParallelLinear.
-        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
-        qzeros = Parameter(
-            torch.empty(
-                scale_and_zero_size,
-                output_size_per_partition // self.quant_config.pack_factor,
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": scale_and_zero_input_dim,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        scales = Parameter(
-            torch.empty(
-                scale_and_zero_size,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": scale_and_zero_input_dim,
-            "output_dim": 1,
-        })
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        g_idx = RowvLLMParameter(data=torch.tensor(
+            [
+                i // self.quant_config.group_size
+                for i in range(input_size_per_partition)
+            ],
+            dtype=torch.int32,
+        ),
+                                 input_dim=0,
+                                 weight_loader=weight_loader)
+        qzeros_args = {
+            "data":
+            torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            "weight_loader":
+            weight_loader
+        }
+        weight_scale_args = {
+            "data":
+            torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            "weight_loader":
+            weight_loader
+        }
+        if scale_and_zero_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)
+
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args)

         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("g_idx", g_idx)
-        set_weight_attrs(g_idx, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)

         layer.exllama_state = exllama_state

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # for torch.compile
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
+        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
+        layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
+
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
         if layer.exllama_state == ExllamaState.UNINITIALIZED:
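The net effect of the gptq.py change above: instead of registering plain torch Parameters and tagging them afterwards with set_weight_attrs metadata, GPTQLinearMethod now constructs vLLMParameter subclasses that carry their sharding metadata and the layer's weight_loader directly, which is what allows GPTQLinearMethod to join WEIGHT_LOADER_V2_SUPPORTED earlier in the diff. A minimal sketch of constructing such a parameter outside of a layer; the class, keyword arguments, and attributes mirror the diff, while the shapes and demo_loader are illustrative assumptions:

    import torch

    from vllm.model_executor.parameter import PackedvLLMParameter


    def demo_loader(param, loaded_weight):
        # Stand-in for the weight_loader a linear layer would normally pass in
        # via extra_weight_attrs; it simply copies the checkpoint tensor over.
        param.data.copy_(loaded_weight)


    # int32-packed 4-bit weights: 32 / 4 = 8 values per word along the input dim.
    qweight = PackedvLLMParameter(data=torch.empty(128 // 8,
                                                   256,
                                                   dtype=torch.int32),
                                  input_dim=0,
                                  output_dim=1,
                                  packed_dim=0,
                                  packed_factor=8,
                                  weight_loader=demo_loader)
    print(qweight.packed_dim, qweight.packed_factor, qweight.data.shape)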
@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.parameter import BasevLLMParameter
 from vllm.model_executor.utils import set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64
@@ -370,10 +371,12 @@ class VocabParallelEmbedding(torch.nn.Module):
         # If param packed on the same dim we are sharding on, then
         # need to adjust offsets of loaded weight by pack_factor.
         if packed_dim is not None and packed_dim == output_dim:
+            packed_factor = param.packed_factor if isinstance(
+                param, BasevLLMParameter) else param.pack_factor
             assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
-                                                       param.pack_factor)
-            start_idx = start_idx // param.pack_factor
-            shard_size = shard_size // param.pack_factor
+                                                       param.packed_factor)
+            start_idx = start_idx // packed_factor
+            shard_size = shard_size // packed_factor
         else:
             assert loaded_weight.shape[output_dim] == self.org_vocab_size

@@ -1,3 +1,4 @@
+from fractions import Fraction
 from typing import Callable, Optional, Union

 import torch
@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
     """

     def __init__(self,
-                 packed_factor: int,
+                 packed_factor: Union[int, Fraction],
                  packed_dim: int,
                  marlin_tile_size: Optional[int] = None,
                  **kwargs):
@@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter):
     """

     def __init__(self,
-                 packed_factor: int,
+                 packed_factor: Union[int, Fraction],
                  packed_dim: int,
                  marlin_tile_size: Optional[int] = None,
                  **kwargs):
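The parameter.py change widens packed_factor to accept a Fraction. GPTQ packs quantized values into 32-bit words, so the pack factor is 32 divided by the weight bit-width, which is not an integer for 3-bit weights. A small illustrative sketch (the sizes below are made up; the point is that floor-dividing by a Fraction still yields the integer row counts the shape math above relies on):

    from fractions import Fraction

    # 32-bit storage words hold 32 / weight_bits quantized values each.
    pack_factor_4bit = Fraction(32, 4)   # == 8, an exact integer
    pack_factor_3bit = Fraction(32, 3)   # cannot be expressed as an int

    input_size_per_partition = 3072
    # int // Fraction floors the quotient and returns an int, so expressions
    # like input_size_per_partition // pack_factor keep working.
    packed_rows = input_size_per_partition // pack_factor_3bit
    print(packed_rows)  # 288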