From fb377d7e74228d477e270d8a8e53410db29ed755 Mon Sep 17 00:00:00 2001 From: Dipika Sikka Date: Tue, 13 Aug 2024 14:30:11 -0400 Subject: [PATCH] [Misc] Update `gptq_marlin` to use new vLLMParameters (#7281) --- .buildkite/test-pipeline.yaml | 10 ++ tests/weight_loading/models.txt | 15 +++ .../run_model_weight_loading_test.sh | 32 +++++ tests/weight_loading/test_weight_loading.py | 20 +++ vllm/model_executor/layers/linear.py | 4 +- .../schemes/compressed_tensors_wNa16.py | 2 +- .../layers/quantization/gptq_marlin.py | 125 +++++++++--------- vllm/model_executor/parameter.py | 124 ++++++++++++----- 8 files changed, 234 insertions(+), 98 deletions(-) create mode 100644 tests/weight_loading/models.txt create mode 100644 tests/weight_loading/run_model_weight_loading_test.sh create mode 100644 tests/weight_loading/test_weight_loading.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e21ae6b0..9b9d4645 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -314,6 +314,16 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s -x lora/test_long_context.py +- label: Weight Loading Multiple GPU Test + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh + + ##### multi gpus test ##### ##### A100 test ##### diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt new file mode 100644 index 00000000..84ca8bcb --- /dev/null +++ b/tests/weight_loading/models.txt @@ -0,0 +1,15 @@ +gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main +gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main +gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main +gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True +gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True +gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main +compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main +compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main +compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main +compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main +compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main +compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main +compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main +compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main +compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main \ No newline at end of file diff --git a/tests/weight_loading/run_model_weight_loading_test.sh b/tests/weight_loading/run_model_weight_loading_test.sh new file mode 100644 index 00000000..0cb45d17 --- /dev/null +++ b/tests/weight_loading/run_model_weight_loading_test.sh @@ -0,0 +1,32 @@ +#!/bin/bash +SUCCESS=0 + +IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "weight_loading/models.txt" + +for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" +do + LOCAL_SUCCESS=0 + IFS=', ' read -r -a array <<< "$MODEL_CONFIG" + + echo "=== RUNNING MODEL: $MODEL_CONFIG ===" + + export QUANTIZATION=${array[0]} + export MODEL_NAME=${array[1]} + export REVISION=${array[2]} + pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$? 
+ + if [[ $LOCAL_SUCCESS == 0 ]]; then + echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" + else + echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" + fi + + SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) + +done + +if [ "${SUCCESS}" -eq "0" ]; then + exit 0 +else + exit 1 +fi diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py new file mode 100644 index 00000000..c13313df --- /dev/null +++ b/tests/weight_loading/test_weight_loading.py @@ -0,0 +1,20 @@ +import os + +MAX_MODEL_LEN = 1024 +MODEL_NAME = os.environ.get("MODEL_NAME", + "robertgshaw2/zephyr-7b-beta-channelwise-gptq") +REVISION = os.environ.get("REVISION", "main") +QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin") + + +def test_weight_loading(vllm_runner): + with vllm_runner(model_name=MODEL_NAME, + revision=REVISION, + dtype="auto", + quantization=QUANTIZATION, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=2) as model: + + output = model.generate_greedy("Hello world!", max_tokens=20) + print(output) + assert output diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e574062e..cececea1 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) -WEIGHT_LOADER_V2_SUPPORTED = ["CompressedTensorsLinearMethod"] +WEIGHT_LOADER_V2_SUPPORTED = [ + "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod" +] def adjust_marlin_shard(param, shard_size, shard_offset): diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 94699c27..7ca8eecb 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): dtype=params_dtype, ) } - if self.group_size == -1: + if not partition_scales: weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args) else: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index b9269753..94eb3f30 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,12 +1,11 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn.parameter import Parameter +from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( @@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) from 
vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: + del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") # Normalize group_size if self.quant_config.group_size != -1: @@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase): scales_and_zp_size = input_size_per_partition // group_size # Quantized weights - qweight = Parameter( - torch.empty( + qweight = PackedvLLMParameter( + data=torch.empty( input_size_per_partition // self.quant_config.pack_factor, output_size_per_partition, dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qweight, - { - **extra_weight_attrs, - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - }, - ) + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) # Activation order - g_idx = Parameter( - torch.empty( - input_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs( - g_idx, - { - **extra_weight_attrs, "input_dim": 0, - "ignore_warning": True - }, - ) + g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) - # Scales - scales = Parameter( - torch.empty( - scales_and_zp_size, - output_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs( - scales, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim, - "output_dim": 1, - }, - ) - - # Quantized zero-points - qzeros = Parameter( + qzeros_args = { + "data": torch.empty( scales_and_zp_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qzeros, - { - **extra_weight_attrs, - "input_dim": scales_and_zp_input_dim, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }, - ) + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) layer.register_parameter("qweight", qweight) layer.register_parameter("g_idx", g_idx) @@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: device = layer.qweight.device + # required by torch.compile + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) + # Allocate marlin workspace layer.workspace = marlin_make_workspace( layer.output_size_per_partition, device) diff --git a/vllm/model_executor/parameter.py 
b/vllm/model_executor/parameter.py index 10239843..c6cfab78 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -9,7 +9,7 @@ from vllm.logger import init_logger __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", "ModelWeightParameter", "ChannelQuantScaleParameter", - "GroupQuantScaleParameter" + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" ] logger = init_logger(__name__) @@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_size = kwargs.get("shard_size") if isinstance( self, - PackedvLLMParameter) and self.packed_dim == self.output_dim: + (PackedColumnParameter, + PackedvLLMParameter)) and self.packed_dim == self.output_dim: shard_size, shard_offset = self.adjust_shard_indexes_for_packing( shard_offset=shard_offset, shard_size=shard_size) @@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter): if isinstance( self, - PackedvLLMParameter) and self.output_dim == self.packed_dim: + (PackedColumnParameter, + PackedvLLMParameter)) and self.output_dim == self.packed_dim: shard_size, shard_offset = self.adjust_shard_indexes_for_packing( shard_offset=shard_offset, shard_size=shard_size) @@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter): param_data.copy_(loaded_weight) -class ModelWeightParameter(_ColumnvLLMParameter): +class RowvLLMParameter(BasevLLMParameter): """ - Parameter class for linear layer weights. Extends the - _ColumnvLLMParameter by adding loading functionality - for linear layers with row parallel functionality. - Requires an input dimension to be defined. + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. """ def __init__(self, input_dim: int, **kwargs): @@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter): self.data.copy_(loaded_weight) -class GroupQuantScaleParameter(ModelWeightParameter): +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): """ Parameter class for weight scales loaded for weights with - grouped quantization. Equivalent to ModelWeightParameter. + grouped quantization. Uses both column and row parallelism. """ pass @@ -171,7 +181,7 @@ class GroupQuantScaleParameter(ModelWeightParameter): class ChannelQuantScaleParameter(_ColumnvLLMParameter): """ Parameter class for weight scales loaded for weights with - channel-wise quantization. Equivalent to _ColumnvLLMParameter. + channel-wise quantization. Equivalent to _ColumnvLLMParameter. """ pass @@ -181,7 +191,7 @@ class PerTensorScaleParameter(BasevLLMParameter): Parameter class for scales where the number of scales is equivalent to the number of logical matrices in fused linear layers (e.g. for QKV, there are 3 scales loaded from disk). - This is relevant to weights with per-tensor quantization. + This is relevant to weights with per-tensor quantization. Adds functionality to map the scalers to a shard during weight loading. @@ -232,15 +242,11 @@ class PerTensorScaleParameter(BasevLLMParameter): param_data.copy_(loaded_weight) -class PackedvLLMParameter(ModelWeightParameter): +class PackedColumnParameter(_ColumnvLLMParameter): """ - Parameter for model weights which are packed on disk. 
- Example: GPTQ Marlin weights are int4 or int8, packed into int32. - Extends the ModelWeightParameter to take in the - packed factor, the packed dimension, and optionally, marlin - tile size for marlin kernels. Adjusts the shard_size and - shard_offset for fused linear layers model weight loading - by accounting for packing and optionally, marlin tile size. + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. """ def __init__(self, @@ -250,7 +256,7 @@ class PackedvLLMParameter(ModelWeightParameter): **kwargs): self._packed_factor = packed_factor self._packed_dim = packed_dim - self._marlin_tile = marlin_tile_size + self._marlin_tile_size = marlin_tile_size super().__init__(**kwargs) @property @@ -262,16 +268,70 @@ class PackedvLLMParameter(ModelWeightParameter): return self._packed_factor @property - def marlin_tile(self): - return self._marlin_tile - - def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset): - return shard_size * self.marlin_tile, shard_offset * self.marlin_tile + def marlin_tile_size(self): + return self._marlin_tile_size def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): - shard_size = shard_size // self.packed_factor - shard_offset = shard_offset // self.packed_factor - if self.marlin_tile is not None: - return self._adjust_shard_indexes_for_marlin( - shard_size, shard_offset) - return shard_size, shard_offset + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. + """ + + def __init__(self, + packed_factor: int, + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size) + + +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, + marlin_tile_size): + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, + marlin_tile_size): + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + if marlin_tile_size is not None: + return _adjust_shard_indexes_for_marlin( + shard_size=shard_size, + shard_offset=shard_offset, + marlin_tile_size=marlin_tile_size) + return shard_size, shard_offset
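
Note (not part of the patch): the shard-index arithmetic that `PackedColumnParameter` and `PackedvLLMParameter` now share through `_adjust_shard_indexes_for_packing` is small enough to illustrate on its own. Below is a minimal standalone sketch of that arithmetic, mirroring the helper added above rather than importing vLLM; the packed_factor of 8 (4-bit GPTQ values packed into int32), the tile size of 16, and the example shard numbers are illustrative assumptions, not values taken from the patch.

def adjust_shard_indexes_for_packing(shard_size, shard_offset,
                                     packed_factor, marlin_tile_size=None):
    # Packed tensors store `packed_factor` logical elements per physical
    # element, so shard extents along the packed dim shrink by that factor.
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        # Marlin re-tiles the packed weight; extents scale back up by the
        # tile size, as in _adjust_shard_indexes_for_marlin above.
        return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
    return shard_size, shard_offset

if __name__ == "__main__":
    # Hypothetical fused-layer shard: 4096 columns starting at offset 8192.
    print(adjust_shard_indexes_for_packing(4096, 8192, packed_factor=8))
    # -> (512, 1024)
    print(adjust_shard_indexes_for_packing(4096, 8192, packed_factor=8,
                                           marlin_tile_size=16))
    # -> (8192, 16384)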