[Misc] Update awq and awq_marlin to use vLLMParameters (#7422)
This commit is contained in:
parent
d3bdfd3ab9
commit
b1e5afc3e7
@ -12,4 +12,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
|
||||
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
|
||||
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
|
||||
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
|
||||
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
|
||||
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
|
||||
awq, casperhansen/mixtral-instruct-awq, main
|
||||
awq_marlin, casperhansen/mixtral-instruct-awq, main
|
||||
|
||||
@ -21,7 +21,8 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
|
||||
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
|
||||
"AWQLinearMethod", "GPTQMarlinLinearMethod"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
@ -101,55 +101,51 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
qzeros = Parameter(
|
||||
torch.empty(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qzeros, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
})
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
set_weight_attrs(qzeros, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
||||
requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
||||
requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data,
|
||||
requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
|
||||
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -151,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
) -> None:
|
||||
del output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
@ -164,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
input_size=input_size,
|
||||
group_size=group_size)
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
|
||||
qzeros = Parameter(
|
||||
torch.empty(
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qzeros, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
})
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
num_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
set_weight_attrs(qzeros, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
@ -228,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
# Here, we handle the repacking
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data,
|
||||
requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
|
||||
requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data,
|
||||
requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace(
|
||||
@ -278,4 +270,4 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
quant_type=self.quant_config.quant_type,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
bias=bias)
|
||||
bias=bias)
|
||||
Loading…
Reference in New Issue
Block a user