[ Misc ] More Cleanup of Marlin (#6359)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent 9da4aad44b
commit babf52dade
@@ -3,7 +3,7 @@
 # We use this for fp8, which HF does not support.
 #
 # Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.2
+# pip install lm-eval==0.4.3
 
 usage() {
     echo``
@@ -10,8 +10,9 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 
@@ -145,6 +146,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
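Note: the new `is_row_parallel` flag simply records whether the layer shards its input dimension. A tiny illustration with hypothetical sizes (TP=2), not part of the commit:

```python
# Hypothetical sizes, TP=2: only row-parallel layers shard the input dim (K).
input_size = 4096
column_parallel_partition = 4096   # ColumnParallelLinear keeps K intact
row_parallel_partition = 2048      # RowParallelLinear splits K across ranks

print(input_size != column_parallel_partition)  # False -> not row-parallel
print(input_size != row_parallel_partition)     # True  -> row-parallel
```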
@@ -158,32 +160,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             input_size=input_size,
             group_size=group_size)
 
-        # Detect sharding of scales/zp
-
-        # By default, no sharding over "input dim"
-        scales_and_zp_size = input_size // group_size
-        scales_and_zp_input_dim = None
-
-        if self.quant_config.desc_act:
-            # Act-order case
-            assert self.quant_config.group_size != -1
-
-            is_k_full = input_size_per_partition == input_size
-
-        else:
-            # No act-order case
-
-            # K is always full due to full alignment with
-            # group-size and shard of scales/zp
-            is_k_full = True
-
-            # If this is a row-parallel case, then shard scales/zp
-            if (input_size != input_size_per_partition
-                    and self.quant_config.group_size != -1):
-                scales_and_zp_size = input_size_per_partition // group_size
-                scales_and_zp_input_dim = 0
-
-        # Init buffers
+        # Determine sharding
+        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
+                                             self.quant_config.group_size,
+                                             is_row_parallel):
+            # By setting scale_dim == None, weight_loader will
+            # repeat the scales on each GPU in TP>1 case.
+            scales_and_zp_input_dim = None
+            scales_and_zp_size = input_size // group_size
+        else:
+            # By setting scale_dim == 0, weight_loader will
+            # shard the scales in TP>1 case.
+            scales_and_zp_input_dim = 0
+            scales_and_zp_size = input_size_per_partition // group_size
 
         # Quantized weights
         qweight = Parameter(
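Note: a small usage sketch of the branch above, with hypothetical layer sizes (TP=2, grouped quantization). The import path is the one added in the imports hunk earlier; values and variable names here are only illustrative.

```python
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_repeat_scales_on_all_ranks)

# Hypothetical sizes: grouped quantization on a row-parallel layer with TP=2.
input_size, input_size_per_partition, group_size = 4096, 2048, 128
is_row_parallel = input_size != input_size_per_partition  # True

if marlin_repeat_scales_on_all_ranks(act_order=False,
                                     group_size=group_size,
                                     is_row_parallel=is_row_parallel):
    # weight_loader repeats the full scales on every rank.
    scales_and_zp_input_dim = None
    scales_and_zp_size = input_size // group_size                  # 32 groups
else:
    # weight_loader shards the scales along the input dim.
    scales_and_zp_input_dim = 0
    scales_and_zp_size = input_size_per_partition // group_size    # 16 groups

print(scales_and_zp_input_dim, scales_and_zp_size)  # 0 16
```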
@@ -268,13 +257,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
-        layer.is_k_full = is_k_full
+        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
+                                           is_row_parallel)
 
     # Checkpoints are serialized in AutoGPTQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
     # Here, we handle the repacking, including the activation reordering case.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
 
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
@@ -312,22 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
-
-        output = ops.gptq_marlin_gemm(reshaped_x,
-                                      layer.qweight,
-                                      layer.scales,
-                                      g_idx=layer.g_idx,
-                                      perm=layer.g_idx_sort_indices,
-                                      workspace=layer.workspace,
-                                      num_bits=self.quant_config.weight_bits,
-                                      size_m=reshaped_x.shape[0],
-                                      size_n=layer.output_size_per_partition,
-                                      size_k=layer.input_size_per_partition,
-                                      is_k_full=layer.is_k_full)
-
-        if bias is not None:
-            output.add_(bias)  # In-place add
-
-        return output.reshape(out_shape)
+        return apply_marlin_linear(
+            input=x,
+            weight=layer.qweight,
+            weight_scale=layer.scales,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=layer.workspace,
+            num_bits=self.quant_config.weight_bits,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            is_k_full=layer.is_k_full,
+            bias=bias)
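Note: the deleted inline body above is what the new `apply_marlin_linear` helper consolidates. Below is a rough sketch reconstructed from those removed lines; the parameter names follow the new call site, and the actual helper in marlin_utils may differ in detail.

```python
# Rough sketch reconstructed from the removed inline code; not the actual
# marlin_utils implementation.
from typing import Optional

import torch

# Assumed import: the removed code calls ops.gptq_marlin_gemm via this alias.
from vllm import _custom_ops as ops


def apply_marlin_linear(input: torch.Tensor,
                        weight: torch.Tensor,
                        weight_scale: torch.Tensor,
                        g_idx: torch.Tensor,
                        g_idx_sort_indices: torch.Tensor,
                        workspace: torch.Tensor,
                        num_bits: int,
                        output_size_per_partition: int,
                        input_size_per_partition: int,
                        is_k_full: bool,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Flatten all leading dims into a single M dim for the GEMM.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  g_idx=g_idx,
                                  perm=g_idx_sort_indices,
                                  workspace=workspace,
                                  num_bits=num_bits,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full)

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
```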
@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
                        requires_grad=False)
 
 
+def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+    return (not act_order) or (act_order and not is_row_parallel)
+
+
+def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+                                      is_row_parallel: bool) -> bool:
+    # Need to repeat scales on every rank if act_ordering or
+    # channelwise and RowParallelLinear
+    is_channelwise = group_size == -1
+    return act_order or (is_channelwise and is_row_parallel)
+
+
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
     return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                               requires_grad=False)
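Note: a tiny check, using the import path added in the earlier imports hunk, that the new `marlin_is_k_full` helper reproduces the branching removed from GPTQMarlinLinearMethod; the argument values are hypothetical.

```python
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    marlin_is_k_full)

# Without act-order, K is always full (old inline code: is_k_full = True).
assert marlin_is_k_full(act_order=False, is_row_parallel=True)
assert marlin_is_k_full(act_order=False, is_row_parallel=False)
# With act-order, K is full only when the input dim is not sharded
# (old inline code: is_k_full = input_size_per_partition == input_size).
assert marlin_is_k_full(act_order=True, is_row_parallel=False)
assert not marlin_is_k_full(act_order=True, is_row_parallel=True)
```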