[Misc] Update gptq_marlin_24 to use vLLMParameters (#7762)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent
665304092d
commit
dd9857f5fa
@ -23,7 +23,7 @@ logger = init_logger(__name__)
|
|||||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||||
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
|
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
|
||||||
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
|
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
|
||||||
"MarlinLinearMethod", "QQQLinearMethod"
|
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,7 +8,10 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||||
|
ChannelQuantScaleParameter,
|
||||||
|
GroupQuantScaleParameter,
|
||||||
|
PackedvLLMParameter)
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -149,7 +152,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
|
|||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
del output_size # Unused.
|
del output_size # Unused.
|
||||||
|
weight_loader = extra_weight_attrs["weight_loader"]
|
||||||
if params_dtype != torch.float16:
|
if params_dtype != torch.float16:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The params dtype must be float16, but got {params_dtype}")
|
f"The params dtype must be float16, but got {params_dtype}")
|
||||||
@ -187,87 +190,80 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
|
|||||||
"Each permutation group must reside on the same gpu")
|
"Each permutation group must reside on the same gpu")
|
||||||
|
|
||||||
# Quantized 4Bit weights packed into Int32.
|
# Quantized 4Bit weights packed into Int32.
|
||||||
qweight = Parameter(
|
qweight = PackedvLLMParameter(
|
||||||
torch.empty(
|
data=torch.empty(
|
||||||
input_size_per_partition // self.quant_config.tile_size // 2,
|
input_size_per_partition // self.quant_config.tile_size // 2,
|
||||||
output_size_per_partition * self.quant_config.tile_size //
|
output_size_per_partition * self.quant_config.tile_size //
|
||||||
self.quant_config.pack_factor,
|
self.quant_config.pack_factor,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
input_dim=0,
|
||||||
)
|
output_dim=1,
|
||||||
set_weight_attrs(
|
packed_dim=1,
|
||||||
qweight,
|
packed_factor=self.quant_config.pack_factor,
|
||||||
{
|
marlin_tile_size=self.quant_config.tile_size,
|
||||||
"input_dim": 0,
|
weight_loader=weight_loader)
|
||||||
"output_dim": 1,
|
|
||||||
"packed_dim": 1,
|
|
||||||
"pack_factor": self.quant_config.pack_factor,
|
|
||||||
"marlin_tile_size": self.quant_config.tile_size,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Meta
|
# Meta
|
||||||
meta = Parameter(
|
meta = PackedvLLMParameter(data=torch.empty(
|
||||||
torch.empty(
|
input_size_per_partition // 8 // 2 // 2,
|
||||||
input_size_per_partition // 8 // 2 // 2,
|
output_size_per_partition * 2,
|
||||||
output_size_per_partition * 2,
|
device="cuda",
|
||||||
device="cuda",
|
dtype=torch.int16,
|
||||||
dtype=torch.int16,
|
),
|
||||||
),
|
input_dim=0,
|
||||||
requires_grad=False,
|
output_dim=1,
|
||||||
)
|
packed_dim=1,
|
||||||
set_weight_attrs(
|
packed_factor=1,
|
||||||
meta,
|
marlin_tile_size=2,
|
||||||
{
|
weight_loader=weight_loader)
|
||||||
"input_dim": 0,
|
|
||||||
"packed_dim": 1,
|
|
||||||
"pack_factor": 1,
|
|
||||||
"output_dim": 1,
|
|
||||||
"marlin_tile_size": 2,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Determine if channelwise or not
|
# Determine if channelwise or not
|
||||||
input_groups = (1 if self.quant_config.group_size == -1 else
|
input_groups = (1 if self.quant_config.group_size == -1 else
|
||||||
input_size_per_partition //
|
input_size_per_partition //
|
||||||
self.quant_config.group_size)
|
self.quant_config.group_size)
|
||||||
|
|
||||||
scales = Parameter(
|
weight_scale_args = {
|
||||||
|
"data":
|
||||||
torch.empty(
|
torch.empty(
|
||||||
input_groups,
|
input_groups,
|
||||||
output_size_per_partition,
|
output_size_per_partition,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
dtype=params_dtype,
|
dtype=params_dtype,
|
||||||
),
|
),
|
||||||
requires_grad=False,
|
"weight_loader":
|
||||||
)
|
weight_loader
|
||||||
set_weight_attrs(
|
}
|
||||||
scales,
|
if input_groups == 1:
|
||||||
{
|
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||||
"input_dim": None if input_groups == 1 else 0,
|
**weight_scale_args)
|
||||||
"output_dim": 1,
|
else:
|
||||||
},
|
scales = GroupQuantScaleParameter(output_dim=1,
|
||||||
)
|
input_dim=0,
|
||||||
|
**weight_scale_args)
|
||||||
|
|
||||||
# Allocate workspace (Used for internal locking mechanism)
|
# Allocate workspace (Used for internal locking mechanism)
|
||||||
max_workspace_size = (
|
max_workspace_size = (
|
||||||
output_size_per_partition //
|
output_size_per_partition //
|
||||||
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
||||||
workspace = Parameter(torch.zeros(max_workspace_size,
|
|
||||||
device="cuda",
|
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
|
||||||
dtype=torch.int),
|
device="cuda",
|
||||||
requires_grad=False)
|
dtype=torch.int),
|
||||||
|
weight_loader=weight_loader)
|
||||||
|
|
||||||
layer.register_parameter("B_24", qweight)
|
layer.register_parameter("B_24", qweight)
|
||||||
set_weight_attrs(qweight, extra_weight_attrs)
|
|
||||||
layer.register_parameter("B_meta", meta)
|
layer.register_parameter("B_meta", meta)
|
||||||
set_weight_attrs(meta, extra_weight_attrs)
|
|
||||||
layer.register_parameter("s", scales)
|
layer.register_parameter("s", scales)
|
||||||
set_weight_attrs(scales, extra_weight_attrs)
|
|
||||||
layer.register_parameter("workspace", workspace)
|
layer.register_parameter("workspace", workspace)
|
||||||
set_weight_attrs(workspace, extra_weight_attrs)
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
# required by torch.compile
|
||||||
|
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
|
||||||
|
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||||
|
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
|
||||||
|
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user