[ Misc ] Support Models With Bias in compressed-tensors integration (#6356)
parent f7160d946a
commit aea19f0989

@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8 -b "auto" -l 1000 -f 5 -t 1
+model_name: "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.593
+  - name: "exact_match,flexible-extract"
+    value: 0.588
+limit: 1000
+num_fewshot: 5

@@ -0,0 +1,11 @@
+# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise -b "auto" -l 1000 -f 5 -t 1
+model_name: "nm-testing/Qwen2-1.5B-Instruct-W8A16-Channelwise"
+tasks:
+- name: "gsm8k"
+  metrics:
+  - name: "exact_match,strict-match"
+    value: 0.595
+  - name: "exact_match,flexible-extract"
+    value: 0.582
+limit: 1000
+num_fewshot: 5

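These two configs follow the same shape as the existing lm-eval baselines: they record the expected GSM8K scores for the bias-carrying INT8 model and the W8A16 channelwise model so CI can compare fresh runs against them. A rough sketch of how such a config can be checked once the harness results are collected into a dict (the RTOL value and function name below are illustrative, not the repository's actual test code):

import yaml

RTOL = 0.05  # illustrative tolerance, not the repo's actual setting

def check_results(config_path: str, measured: dict) -> None:
    # measured maps (task_name, metric_name) -> value produced by lm_eval.
    with open(config_path) as f:
        cfg = yaml.safe_load(f)
    for task in cfg["tasks"]:
        for metric in task["metrics"]:
            got = measured[(task["name"], metric["name"])]
            want = metric["value"]
            assert abs(got - want) <= RTOL, (
                f"{task['name']}/{metric['name']}: got {got}, expected {want}")
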
@@ -2,3 +2,4 @@ Meta-Llama-3-8B-Instruct.yaml
 Meta-Llama-3-8B-Instruct-FP8.yaml
 Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
 Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
+Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml

@@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
 done

 lm_eval --model vllm \
-  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true \
+  --model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray" \
   --tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
   --batch_size $BATCH_SIZE

@@ -12,7 +12,10 @@ from tests.quantization.utils import is_quant_method_supported
 from .utils import check_logprobs_close

 MODELS = [
+    # No bias
     "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Per-Token-Test",
+    # Bias
+    "neuralmagic/Qwen2-1.5B-Instruct-quantized.w8a8"
 ]

 MAX_TOKENS = 32

@@ -267,10 +267,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):

         """

-        if bias is not None:
-            raise ValueError("bias is not supported for this linear method")
-
         scheme = layer.scheme
         if scheme is None:
             raise ValueError("A scheme must be defined for each layer")
-        return scheme.apply_weights(layer, x)
+        return scheme.apply_weights(layer, x, bias=bias)

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Optional

 import torch


@@ -20,14 +21,16 @@ class CompressedTensorsScheme(ABC):
         raise NotImplementedError

     @abstractmethod
-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]):
         """
         Run the forward pass for the particular scheme. This is where
         scheme-specific dequant/quant steps/kernels should be applied.

-        :param layer: toch.nn.Module with the registered weights and
+        :param layer: torch.nn.Module with the registered weights and
             other parameters relevant to the particular scheme.
         :param x: input to the layer
+        :param bias: bias parameter

         """
         raise NotImplementedError

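With the abstract signature widened, every concrete scheme has to accept the extra argument and decide how to apply it. A minimal hypothetical subclass illustrating the contract (the ToyScheme name and the plain float matmul are made up for illustration; the real schemes below use their quantized kernels instead):

from typing import Optional

import torch


class ToyScheme:
    # Stand-in for a CompressedTensorsScheme subclass; a real scheme also
    # implements create_weights and runs a quantized kernel here.
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        out = x @ layer.weight.t()      # scheme-specific matmul goes here
        if bias is not None:            # kernels that cannot fuse the bias
            out = out + bias            # apply it as an epilogue add
        return out
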
@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional

 import torch
 import torch.nn.functional as F


@@ -37,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, {"weight_loader": weight_loader})

-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
-        weight = layer.weight
-        return F.linear(x, weight)
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
+        return F.linear(x, layer.weight, bias)

@@ -118,7 +118,9 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                               requires_grad=False)
         layer.workspace = workspace

-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         qweight = layer.weight_packed
         meta = layer.meta
         scales = layer.scale_packed


@@ -135,4 +137,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
                                             size_n, size_k)

         output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
         return output

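The sparse 2:4 gemm call here takes no bias argument, so the bias is applied as an in-place epilogue after the output is reshaped; output.add_(bias) broadcasts the per-output-channel bias over every leading dimension. A small self-contained check of that broadcast behavior (shapes are made up for illustration):

import torch

# Illustrative shapes: 4 tokens, 8 output channels.
output = torch.randn(4, 8)      # kernel output after the final .view(...)
bias = torch.randn(8)           # one bias value per output channel

expected = output + bias        # out-of-place reference
output.add_(bias)               # in-place broadcasted add, as in the diff

assert torch.equal(output, expected)
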
@@ -1,4 +1,4 @@
-from typing import Callable, List
+from typing import Callable, List, Optional

 import torch
 from torch.nn import Parameter


@@ -78,8 +78,11 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
                               **layer_kwargs)
             layer.register_parameter("input_scale", scale)

-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         return apply_int8_linear(input=x,
                                  weight=layer.weight,
                                  weight_scale=layer.weight_scale,
-                                 input_scale=layer.input_scale)
+                                 input_scale=layer.input_scale,
+                                 bias=bias)

@@ -148,7 +148,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             group_size=layer.group_size)
         replace_tensor(layer, "weight_scale", marlin_scales)

-    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
+                      bias: Optional[torch.Tensor]) -> torch.Tensor:
+
         return apply_marlin_linear(
             input=x,
             weight=layer.weight_packed,


@@ -159,4 +161,5 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
             num_bits=self.num_bits,
             output_size_per_partition=layer.output_size_per_partition,
             input_size_per_partition=layer.input_size_per_partition,
-            is_k_full=True)
+            is_k_full=True,
+            bias=bias)

@@ -148,9 +148,6 @@ def apply_int8_linear(
     input_scale: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
 ):
-    if bias is not None:
-        raise NotImplementedError("W8A8 with int8 does not yet support bias.")
-
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.


@@ -160,4 +157,5 @@ def apply_int8_linear(
                                  weight,
                                  scale_a=x_scale,
                                  scale_b=weight_scale,
-                                 out_dtype=input.dtype)
+                                 out_dtype=input.dtype,
+                                 bias=bias)

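Dropping the NotImplementedError means apply_int8_linear now forwards the bias straight into the fused scaled matmul. For intuition, here is the math it covers, written out in plain PyTorch with dynamic per-tensor activation quantization and per-channel weight quantization (a sketch of the semantics only, with made-up shapes, not the kernels actually used):

import torch

x = torch.randn(4, 8)        # activations: 4 tokens, hidden size 8
w = torch.randn(16, 8)       # weights: one row per output channel
bias = torch.randn(16)

# Dynamic per-tensor int8 quantization of the activations.
x_scale = x.abs().max() / 127.0
x_q = torch.clamp((x / x_scale).round(), -128, 127)

# Per-output-channel int8 quantization of the weights.
w_scale = w.abs().amax(dim=1, keepdim=True) / 127.0
w_q = torch.clamp((w / w_scale).round(), -128, 127)

# Int8 matmul, dequantize with both scales, then the bias epilogue.
out = (x_q @ w_q.t()) * x_scale * w_scale.t() + bias
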