[Misc] Update awq and awq_marlin to use vLLMParameters (#7422)
Commit b1e5afc3e7 (parent d3bdfd3ab9)
@@ -12,4 +12,6 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
 compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
 compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
 compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
+awq, casperhansen/mixtral-instruct-awq, main
+awq_marlin, casperhansen/mixtral-instruct-awq, main
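The two added rows register the same AWQ checkpoint once per backend, so the weight-loading test exercises both the awq and awq_marlin paths against casperhansen/mixtral-instruct-awq. A minimal sketch of how rows in this comma-separated models list might be parsed; ModelCase and parse_model_cases are hypothetical names for illustration, not the test harness's actual API:

    from typing import List, NamedTuple


    class ModelCase(NamedTuple):
        # Hypothetical record type; the real test harness may differ.
        quant_method: str  # e.g. "awq" or "awq_marlin"
        model_id: str      # e.g. "casperhansen/mixtral-instruct-awq"
        revision: str      # e.g. "main"


    def parse_model_cases(text: str) -> List[ModelCase]:
        cases = []
        for line in text.splitlines():
            line = line.strip()
            if not line:
                continue
            method, model, rev = (field.strip() for field in line.split(","))
            cases.append(ModelCase(method, model, rev))
        return cases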
@@ -21,7 +21,8 @@ from vllm.model_executor.utils import set_weight_attrs
 logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
-    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod"
+    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
+    "AWQLinearMethod", "GPTQMarlinLinearMethod"
 ]
 
 
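With AWQLinearMethod and AWQMarlinLinearMethod added to WEIGHT_LOADER_V2_SUPPORTED, both backends are opted into the vLLMParameter-based weight-loading path. A sketch of how a loader could gate on this list, assuming dispatch by class name; the use_v2_loader helper is a hypothetical illustration, not the loader's actual dispatch code:

    WEIGHT_LOADER_V2_SUPPORTED = [
        "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
        "AWQLinearMethod", "GPTQMarlinLinearMethod"
    ]


    def use_v2_loader(quant_method: object) -> bool:
        # Select the v2 path based on the layer's quantization method class.
        return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED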
@@ -1,13 +1,13 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 
 class AWQConfig(QuantizationConfig):
@@ -101,55 +101,51 @@ class AWQLinearMethod(LinearMethodBase):
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
 
-        qweight = Parameter(
-            torch.empty(
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        qzeros = Parameter(
-            torch.empty(
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
-        scales = Parameter(
-            torch.empty(
-                input_size_per_partition // self.quant_config.group_size,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": 0,
-            "output_dim": 1,
-        })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
+
+        scales = GroupQuantScaleParameter(data=torch.empty(
+            input_size_per_partition // self.quant_config.group_size,
+            output_size_per_partition,
+            dtype=params_dtype,
+        ),
+                                          input_dim=0,
+                                          output_dim=1,
+                                          weight_loader=weight_loader)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data,
+                                           requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
+                                          requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data,
+                                          requires_grad=False)
 
     def apply(self,
               layer: torch.nn.Module,
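The recurring pattern in this commit: the bare torch.nn.Parameter plus set_weight_attrs side channel is replaced by parameter classes whose constructors carry the loading metadata (input_dim, output_dim, packed_dim, packed_factor, weight_loader) directly. A self-contained sketch of the idea, assuming a simplified stand-in class; the real PackedvLLMParameter in vllm.model_executor.parameter has more behavior (e.g. shard-aware loading helpers):

    import torch


    class PackedParamSketch(torch.nn.Parameter):
        """Simplified stand-in for PackedvLLMParameter (illustration only)."""

        def __new__(cls, data: torch.Tensor, **kwargs):
            # Quantized weights are frozen buffers; gradients are never needed.
            return super().__new__(cls, data, requires_grad=False)

        def __init__(self, data: torch.Tensor, *, input_dim: int,
                     output_dim: int, packed_dim: int, packed_factor: int,
                     weight_loader) -> None:
            # The loading metadata lives on the parameter itself, replacing
            # the old set_weight_attrs(param, {...}) side channel.
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.packed_dim = packed_dim
            self.packed_factor = packed_factor
            self.weight_loader = weight_loader


    # Toy construction mirroring qweight above: 4-bit values packed into
    # int32 give a pack factor of 8.
    qweight = PackedParamSketch(
        torch.empty(128, 512 // 8, dtype=torch.int32),
        input_dim=0,
        output_dim=1,
        packed_dim=1,
        packed_factor=8,
        weight_loader=lambda param, loaded: param.data.copy_(loaded))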
@@ -1,12 +1,10 @@
 from typing import Any, Dict, List, Optional
 
 import torch
-from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               set_weight_attrs)
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -14,6 +12,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
     replace_tensor, verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 from vllm.scalar_type import scalar_types
 
 logger = init_logger(__name__)
@@ -151,6 +151,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -164,59 +165,44 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             input_size=input_size,
             group_size=group_size)
 
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
 
         num_groups = input_size_per_partition // group_size
 
-        qzeros = Parameter(
-            torch.empty(
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
                 num_groups,
                 output_size_per_partition // self.quant_config.pack_factor,
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qzeros, {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-            })
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader)
 
-        scales = Parameter(
-            torch.empty(
-                num_groups,
-                output_size_per_partition,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(scales, {
-            "input_dim": 0,
-            "output_dim": 1,
-        })
+        scales = GroupQuantScaleParameter(data=torch.empty(
+            num_groups,
+            output_size_per_partition,
+            dtype=params_dtype,
+        ),
+                                          input_dim=0,
+                                          output_dim=1,
+                                          weight_loader=weight_loader)
 
         layer.register_parameter("qweight", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qzeros", qzeros)
-        set_weight_attrs(qzeros, extra_weight_attrs)
         layer.register_parameter("scales", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
 
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
@@ -228,6 +214,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):
     # Here, we handle the repacking
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
+        layer.qweight = torch.nn.Parameter(layer.qweight.data,
+                                           requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data,
+                                          requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data,
+                                          requires_grad=False)
 
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(
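Both process_weights_after_loading hooks now rebuild qweight, qzeros, and scales as plain torch.nn.Parameter objects, so the metadata-carrying vLLMParameter subclasses exist only while weights stream in; afterwards the layer holds ordinary frozen tensors. A hypothetical helper condensing the repeated three-line pattern (not part of vLLM):

    import torch


    def strip_loading_metadata(layer: torch.nn.Module) -> None:
        # Re-wrap each loaded parameter as a plain, frozen nn.Parameter
        # sharing the same underlying storage; nothing is copied.
        for name in ("qweight", "qzeros", "scales"):
            param = getattr(layer, name)
            setattr(layer, name,
                    torch.nn.Parameter(param.data, requires_grad=False))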
@@ -278,4 +270,4 @@ class AWQMarlinLinearMethod(LinearMethodBase):
             quant_type=self.quant_config.quant_type,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
             bias=bias)