From babf52dade78ff3b1bea6cb6e9f4151dfd630251 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Sat, 13 Jul 2024 06:21:37 -0400
Subject: [PATCH] [ Misc ] More Cleanup of Marlin (#6359)

Co-authored-by: Robert Shaw
---
 .../run-lm-eval-gsm-vllm-baseline.sh          |  2 +-
 .../layers/quantization/gptq_marlin.py        | 78 ++++++++-----------
 .../layers/quantization/utils/marlin_utils.py | 12 +++
 3 files changed, 44 insertions(+), 48 deletions(-)

diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
index d68c6993..1bddbd89 100644
--- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
@@ -3,7 +3,7 @@
 # We use this for fp8, which HF does not support.
 #
 # Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.2
+# pip install lm-eval==0.4.3
 
 usage() {
     echo``
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 7b808f52..07a73d06 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -10,8 +10,9 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
     verify_marlin_supported, verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -145,6 +146,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
     ) -> None:
         del output_size
         output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition
 
         # Normalize group_size
         if self.quant_config.group_size != -1:
@@ -158,32 +160,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
             input_size=input_size,
             group_size=group_size)
 
-        # Detect sharding of scales/zp
-
-        # By default, no sharding over "input dim"
-        scales_and_zp_size = input_size // group_size
-        scales_and_zp_input_dim = None
-
-        if self.quant_config.desc_act:
-            # Act-order case
-            assert self.quant_config.group_size != -1
-
-            is_k_full = input_size_per_partition == input_size
-
+        # Determine sharding
+        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
+                                             self.quant_config.group_size,
+                                             is_row_parallel):
+            # By setting scale_dim == None, weight_loader will
+            # repeat the scales on each GPU in TP>1 case.
+            scales_and_zp_input_dim = None
+            scales_and_zp_size = input_size // group_size
         else:
-            # No act-order case
-
-            # K is always full due to full alignment with
-            # group-size and shard of scales/zp
-            is_k_full = True
-
-            # If this is a row-parallel case, then shard scales/zp
-            if (input_size != input_size_per_partition
-                    and self.quant_config.group_size != -1):
-                scales_and_zp_size = input_size_per_partition // group_size
-                scales_and_zp_input_dim = 0
-
-        # Init buffers
+            # By setting scale_dim == 0, weight_loader will
+            # shard the scales in TP>1 case.
+            scales_and_zp_input_dim = 0
+            scales_and_zp_size = input_size_per_partition // group_size
 
         # Quantized weights
         qweight = Parameter(
@@ -268,13 +257,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         layer.input_size_per_partition = input_size_per_partition
         layer.output_size_per_partition = output_size_per_partition
         layer.input_size = input_size
-        layer.is_k_full = is_k_full
+        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
+                                           is_row_parallel)
 
     # Checkpoints are serialized in AutoGPTQ format, which is different from the
     # marlin format. This function is called after the weights are loaded.
     # Here, we handle the repacking, including the activation reordering case.
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = layer.qweight.device
+
         # Allocate marlin workspace
         layer.workspace = marlin_make_workspace(
             layer.output_size_per_partition, device)
@@ -312,22 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
-
-        output = ops.gptq_marlin_gemm(reshaped_x,
-                                      layer.qweight,
-                                      layer.scales,
-                                      g_idx=layer.g_idx,
-                                      perm=layer.g_idx_sort_indices,
-                                      workspace=layer.workspace,
-                                      num_bits=self.quant_config.weight_bits,
-                                      size_m=reshaped_x.shape[0],
-                                      size_n=layer.output_size_per_partition,
-                                      size_k=layer.input_size_per_partition,
-                                      is_k_full=layer.is_k_full)
-
-        if bias is not None:
-            output.add_(bias)  # In-place add
-
-        return output.reshape(out_shape)
+        return apply_marlin_linear(
+            input=x,
+            weight=layer.qweight,
+            weight_scale=layer.scales,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=layer.workspace,
+            num_bits=self.quant_config.weight_bits,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            is_k_full=layer.is_k_full,
+            bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index 612c5fd2..764f0a6f 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
                        requires_grad=False)
 
 
+def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+    return (not act_order) or (act_order and not is_row_parallel)
+
+
+def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+                                      is_row_parallel: bool) -> bool:
+    # Need to repeat scales on every rank if act_ordering or
+    # channelwise and RowParallelLinear
+    is_channelwise = group_size == -1
+    return act_order or (is_channelwise and is_row_parallel)
+
+
 def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
     return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                               requires_grad=False)
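
A quick way to sanity-check the two helpers introduced in marlin_utils.py is to exercise them standalone. The sketch below is illustrative only and not part of the commit: it copies the helper bodies verbatim from the diff above and prints the resulting is_k_full / repeat-scales decisions for every combination of act-order, group size, and row-parallel sharding. The driver loop and the example group_size values are assumptions for the demo.

# demo_marlin_sharding.py -- illustrative sketch, not part of this patch.
# Copies the two helpers added to marlin_utils.py and tabulates their output.


def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    # K is only "partial" when act-order is combined with a row-parallel
    # (input-dim sharded) layer.
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    # Repeat scales on every rank when act-order is enabled, or when the
    # layer is channelwise (group_size == -1) and row-parallel.
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


if __name__ == "__main__":
    # group_size values chosen for illustration: -1 (channelwise) and 128.
    for act_order in (False, True):
        for group_size in (-1, 128):
            for is_row_parallel in (False, True):
                k_full = marlin_is_k_full(act_order, is_row_parallel)
                repeat = marlin_repeat_scales_on_all_ranks(
                    act_order, group_size, is_row_parallel)
                print(f"act_order={act_order!s:<5} "
                      f"group_size={group_size:>4} "
                      f"row_parallel={is_row_parallel!s:<5} -> "
                      f"is_k_full={k_full!s:<5} repeat_scales={repeat}")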