[Misc] Update gptq_marlin_24 to use vLLMParameters (#7762)

Co-authored-by: Michael Goin <michael@neuralmagic.com>
Dipika Sikka 2024-08-26 17:44:54 -04:00 committed by GitHub
parent 665304092d
commit dd9857f5fa
2 changed files with 50 additions and 54 deletions
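
For context, the pattern this commit adopts: instead of registering a plain torch Parameter and attaching sharding metadata afterwards via set_weight_attrs, the metadata and the weight_loader callback are passed directly to a vLLM parameter class from vllm.model_executor.parameter. A minimal before/after sketch with toy shapes and a placeholder loader (illustrative only, not the real layer code):

import torch
from torch.nn import Parameter
from vllm.model_executor.parameter import PackedvLLMParameter
from vllm.model_executor.utils import set_weight_attrs

# Before: plain Parameter, loader metadata bolted on afterwards.
qweight = Parameter(torch.empty(16, 32, dtype=torch.int32), requires_grad=False)
set_weight_attrs(qweight, {
    "input_dim": 0,
    "output_dim": 1,
    "packed_dim": 1,
    "pack_factor": 8,
})

# After: the metadata and the loader travel with the parameter object itself.
def weight_loader(param, loaded_weight, *args):
    # Placeholder loader for the sketch; real layers receive theirs from vLLM.
    param.data.copy_(loaded_weight)

qweight = PackedvLLMParameter(data=torch.empty(16, 32, dtype=torch.int32),
                              input_dim=0,
                              output_dim=1,
                              packed_dim=1,
                              packed_factor=8,
                              weight_loader=weight_loader)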

vllm/model_executor/layers/linear.py

@@ -23,7 +23,7 @@ logger = init_logger(__name__)
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
 ]
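
This hunk registers GPTQMarlin24LinearMethod in WEIGHT_LOADER_V2_SUPPORTED, the allow-list the linear layers consult when deciding whether to route loading through weight_loader_v2 (the vLLM-parameter-based path). A rough paraphrase of that dispatch, assuming the surrounding logic in linear.py (not the verbatim source):

# Hypothetical simplification of the loader selection in linear.py.
weight_loader = (
    self.weight_loader_v2            # understands the vLLM parameter classes
    if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
    else self.weight_loader)         # legacy set_weight_attrs-based path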

vllm/model_executor/layers/quantization/gptq_marlin_24.py

@@ -8,7 +8,10 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)
@@ -149,7 +152,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]

         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -187,87 +190,80 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
                 "Each permutation group must reside on the same gpu")

         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size // 2,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)

         # Meta
-        meta = Parameter(
-            torch.empty(
-                input_size_per_partition // 8 // 2 // 2,
-                output_size_per_partition * 2,
-                device="cuda",
-                dtype=torch.int16,
-            ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            meta,
-            {
-                "input_dim": 0,
-                "packed_dim": 1,
-                "pack_factor": 1,
-                "output_dim": 1,
-                "marlin_tile_size": 2,
-            },
-        )
+        meta = PackedvLLMParameter(data=torch.empty(
+            input_size_per_partition // 8 // 2 // 2,
+            output_size_per_partition * 2,
+            device="cuda",
+            dtype=torch.int16,
+        ),
+                                   input_dim=0,
+                                   output_dim=1,
+                                   packed_dim=1,
+                                   packed_factor=1,
+                                   marlin_tile_size=2,
+                                   weight_loader=weight_loader)

         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)

-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)

         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)

         layer.register_parameter("B_24", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("B_meta", meta)
-        set_weight_attrs(meta, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

     def apply(
         self,