diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
new file mode 100644
index 00000000..43ff2bc5
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
+model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.593
+  - name: "exact_match,flexible-extract"
+    value: 0.588
+limit: 1000
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml
new file mode 100644
index 00000000..259799ba
--- /dev/null
+++ b/.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-W8A16-compressed-tensors.yaml
@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
+model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.595
+  - name: "exact_match,flexible-extract"
+    value: 0.582
+limit: 1000
+num_fewshot: 5
diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt
index 3300ca64..3d1306f6 100644
--- a/.buildkite/lm-eval-harness/configs/models-small.txt
+++ b/.buildkite/lm-eval-harness/configs/models-small.txt
@@ -2,3 +2,4 @@ Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
index 933733e9..d68c6993 100644
--- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
+++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh
@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done
 
 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE
diff --git a/tests/models/test_compressed_tensors.py b/tests/models/test_compressed_tensors.py
index 9a0054c5..da47d5f3 100644
--- a/tests/models/test_compressed_tensors.py
+++ b/tests/models/test_compressed_tensors.py
@@ -12,7 +12,10 @@ from tests.quantization.utils import is_quant_method_supported
 from .utils import check_logprobs_close
 
 MODELS = [
+    # No bias
     "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
+    # Bias
+    "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
 ]
 
 MAX_TOKENS = 32
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index c711fd14..524b4c89 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -267,10 +267,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         """
-        if bias is not None:
-            raise ValueError("bias is not supported for this linear method")
-
         scheme = layer.scheme
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
-        return scheme.apply_weights(layer, x)
+        return scheme.apply_weights(layer, x, bias=bias)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
index 119f6cd9..3aa91307 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Optional
 
 import torch
 
@@ -20,14 +21,16 @@ class CompressedTensorsScheme(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]):
         """
         Run the forward pass for the particular scheme. This is where
         scheme-specific dequant/quant steps/kernels should be applied.
 
-        :param layer: toch.nn.Module with the registered weights and
+        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
         :param x: input to the layer
+        :param bias: bias parameter
 
         """
         raise NotImplementedError
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
index f5911bc3..2c7fe3e0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -37,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, {"weight_loader": weight_loader})
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        return F.linear(x, weight)
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
+        return F.linear(x, layer.weight, bias)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
index 3c07d6b6..54bf85c0 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -118,7 +118,9 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                               requires_grad=False)
         layer.workspace = workspace
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         qweight = layer.weight_packed
         meta = layer.meta
         scales = layer.scale_packed
@@ -135,4 +137,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                                           size_n, size_k)
 
         output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
         return output
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index e70504ec..6fec5d01 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional
 
 import torch
 from torch.nn import Parameter
@@ -78,8 +78,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                               **layer_kwargs)
             layer.register_parameter("input_scale", scale)
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale)
+                                 input_scale=layer.input_scale,
+                                 bias=bias)
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
index ed9fa73c..187a3f98 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
@@ -148,7 +148,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)
 
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         return apply_marlin_linear(
             input=x,
             weight=layer.weight_packed,
@@ -159,4 +161,5 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             num_bits=self.num_bits,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            is_k_full=True)
+            is_k_full=True,
+            bias=bias)
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 81b7fdb7..30a82e1b 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -148,9 +148,6 @@ def apply_int8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
 ):
-    if bias is not None:
-        raise NotImplementedError("W8A8 with int8 does not yet support bias.")
-
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
@@ -160,4 +157,5 @@
                                  weight,
                                  scale_a=x_scale,
                                  scale_b=weight_scale,
-                                 out_dtype=input.dtype)
+                                 out_dtype=input.dtype,
+                                 bias=bias)
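Note on the bias handling above: the schemes take one of two approaches. Either the underlying kernel accepts a bias operand directly (F.linear in the unquantized scheme, apply_marlin_linear for WNA16, and cutlass_scaled_mm via apply_int8_linear), or the bias is added to the GEMM output afterwards (the W4A16 sparse-24 path, which does an in-place add). The torch-only sketch below is not part of this diff; the shapes and values are illustrative, and it only demonstrates that the two patterns produce the same result.

# Torch-only sketch of the two bias-handling patterns used in this diff.
# Shapes and tensors here are illustrative, not taken from vLLM.
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)       # layer input, [batch, in_features]
weight = torch.randn(8, 16)  # [out_features, in_features]
bias = torch.randn(8)

# Pattern 1: hand the bias to the kernel itself, as in
# CompressedTensorsUnquantized (F.linear) and apply_int8_linear
# (cutlass_scaled_mm now receives bias=bias).
out_fused = F.linear(x, weight, bias)

# Pattern 2: run the GEMM first, then add the bias in place, as in
# CompressedTensorsW4A16Sparse24 after the sparse-marlin kernel.
out_post = F.linear(x, weight)
if bias is not None:
    out_post.add_(bias)

assert torch.allclose(out_fused, out_post)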