[ Misc ] Support Models With Bias in compressed-tensors integration (#6356)

Robert Shaw 2024-07-12 11:11:29 -04:00 committed by GitHub
parent f7160d946a
commit aea19f0989
12 changed files with 58 additions and 21 deletions

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.593
- name: "exact_match,flexible-extract"
value: 0.588
limit: 1000
num_fewshot: 5
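The CI harness compares these expected GSM8K scores against freshly measured ones. A minimal sketch of that comparison, assuming a YAML layout like the file above (check_config, RTOL, and the measured dict are illustrative names, not the actual script in .buildkite/lm-eval-harness):

import yaml

RTOL = 0.05  # assumed tolerance; the real harness may use a different value

def check_config(config_path: str, measured: dict) -> None:
    # measured maps task name -> metric name -> measured value
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            expected = metric["value"]
            got = measured[task["name"]][metric["name"]]
            assert abs(got - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: got {got}, expected {expected}")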

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.595
- name: "exact_match,flexible-extract"
value: 0.582
limit: 1000
num_fewshot: 5

View File

@@ -2,3 +2,4 @@ Meta-Llama-3-8B-Instruct.yaml
Meta-Llama-3-8B-Instruct-FP8.yaml
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml

View File

@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
done
lm_eval --model vllm \
---model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
+--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
--batch_size $BATCH_SIZE
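The added model_args entry selects the Ray-based distributed executor for the vLLM backend inside lm_eval. A hedged standalone equivalent in Python, assuming vLLM's engine arguments accept distributed_executor_backend as this flag suggests:

from vllm import LLM

llm = LLM(
    model="neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8",
    tensor_parallel_size=1,              # $TP_SIZE in the script
    distributed_executor_backend="ray",  # the newly added flag
)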

View File

@@ -12,7 +12,10 @@ from tests.quantization.utils import is_quant_method_supported
from .utils import check_logprobs_close
MODELS = [
+# No bias
"nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
+# Bias
+"neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
]
MAX_TOKENS = 32
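A minimal smoke-test sketch of what the expanded model list exercises: loading the bias-carrying W8A8 checkpoint and generating greedily (the real test compares logprobs via check_logprobs_close; this snippet is only illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8")
params = SamplingParams(temperature=0.0, max_tokens=32)  # MAX_TOKENS above
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)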

View File

@@ -267,10 +267,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
"""
-if bias is not None:
-raise ValueError("bias is not supported for this linear method")
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
-return scheme.apply_weights(layer, x)
+return scheme.apply_weights(layer, x, bias=bias)
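Taken together, a simplified sketch of CompressedTensorsLinearMethod.apply after this hunk (the method name and signature follow vLLM's LinearMethodBase and are assumed, since the diff does not show them): the bias is no longer rejected but forwarded to the per-layer scheme.

def apply(self, layer, x, bias=None):
    # The scheme check is unchanged; only the bias handling differs.
    scheme = layer.scheme
    if scheme is None:
        raise ValueError("A scheme must be defined for each layer")
    return scheme.apply_weights(layer, x, bias=bias)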

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
+from typing import Optional
import torch
@@ -20,14 +21,16 @@ class CompressedTensorsScheme(ABC):
raise NotImplementedError
@abstractmethod
-def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+bias: Optional[torch.Tensor]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
-:param layer: toch.nn.Module with the registered weights and
+:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
+:param bias: bias parameter
"""
raise NotImplementedError

View File

@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional
import torch
import torch.nn.functional as F
@@ -37,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})
-def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-weight = layer.weight
-return F.linear(x, weight)
+def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+bias: Optional[torch.Tensor]) -> torch.Tensor:
+return F.linear(x, layer.weight, bias)

View File

@@ -118,7 +118,9 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
requires_grad=False)
layer.workspace = workspace
-def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+bias: Optional[torch.Tensor]) -> torch.Tensor:
qweight = layer.weight_packed
meta = layer.meta
scales = layer.scale_packed
@@ -135,4 +137,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+if bias is not None:
+output.add_(bias) # In-place add
return output
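Unlike the int8 and Marlin WNA16 paths further below, which hand the bias straight to their kernel helpers, this scheme adds the bias to the GEMM output afterwards, presumably because the 2:4 sparse Marlin call used here takes no bias argument. A plain-PyTorch illustration of that in-place epilogue, with illustrative shapes:

import torch

output = torch.zeros(4, 8)  # e.g. [num_tokens, out_features] GEMM result
bias = torch.arange(8.0)    # [out_features]
output.add_(bias)           # in-place broadcast add over the last dimension
assert torch.equal(output[0], bias)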

View File

@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
@@ -78,8 +78,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
**layer_kwargs)
layer.register_parameter("input_scale", scale)
-def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
-input_scale=layer.input_scale)
+input_scale=layer.input_scale,
+bias=bias)

View File

@@ -148,7 +148,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
-def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_marlin_linear(
input=x,
weight=layer.weight_packed,
@@ -159,4 +161,5 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
num_bits=self.num_bits,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
-is_k_full=True)
+is_k_full=True,
+bias=bias)

View File

@@ -148,9 +148,6 @@ def apply_int8_linear(
input_scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
):
-if bias is not None:
-raise NotImplementedError("W8A8 with int8 does not yet support bias.")
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
@@ -160,4 +157,5 @@ def apply_int8_linear(
weight,
scale_a=x_scale,
scale_b=weight_scale,
-out_dtype=input.dtype)
+out_dtype=input.dtype,
+bias=bias)
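For intuition, a hedged plain-PyTorch reference of what the fused-bias scaled int8 matmul computes (the real work happens inside the CUTLASS kernel, so accumulation and rounding details may differ; shapes are illustrative):

from typing import Optional
import torch

def scaled_mm_reference(x_q: torch.Tensor,      # int8 activations, [M, K]
                        w_q: torch.Tensor,      # int8 weights, [K, N]
                        scale_a: torch.Tensor,  # per-token [M, 1] or scalar
                        scale_b: torch.Tensor,  # per-channel [1, N] or scalar
                        bias: Optional[torch.Tensor],  # [N] or None
                        out_dtype: torch.dtype) -> torch.Tensor:
    # Integer matmul with wide accumulation, then dequantize with both scales.
    acc = (x_q.to(torch.int32) @ w_q.to(torch.int32)).to(torch.float32)
    out = acc * scale_a * scale_b
    if bias is not None:
        out = out + bias  # bias is applied after dequantization
    return out.to(out_dtype)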