[Misc] Add per channel support for static activation quantization; update w8a8 schemes to share base classes (#5650)
This commit is contained in:
parent e83db9e7e3 · commit 4a30d7e3cc
@@ -13,8 +13,12 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsW8A8StaticTensor)


-def test_compressed_tensors_w8a8_static_setup(vllm_runner):
-    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
+@pytest.mark.parametrize("model_args", [
+    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", "tensor"),
+    ("nm-testing/tinyllama-oneshot-w8-channel-a8-tensor", "channel"),
+])
+def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
+    model_path, strategy = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]
@@ -33,12 +37,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):

     assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8StaticTensor)
+    assert qkv_proj.scheme.strategy == strategy
     assert qkv_proj.weight.dtype is torch.int8
     assert o_proj.weight.dtype is torch.int8
     assert gate_up_proj.weight.dtype is torch.int8

-    assert qkv_proj.weight_scale.shard_splitter is not None
-    assert qkv_proj.weight_scale.logical_widths is not None
+    if qkv_proj.scheme.strategy == "tensor":
+        assert qkv_proj.weight_scale.shard_splitter is not None
+        assert qkv_proj.weight_scale.logical_widths is not None
     assert qkv_proj.input_scale.dtype is torch.float32

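For orientation, the two parametrized strategies differ only in the granularity of the weight scale: per-tensor keeps one scalar for the whole weight, per-channel keeps one scale per output channel. A minimal sketch in plain PyTorch (illustrative only, not vLLM's kernels; quantize_int8 is a hypothetical helper) showing the scale shapes the test distinguishes:

import torch

def quantize_int8(weight: torch.Tensor, per_channel: bool):
    """Symmetric int8 quantization; returns (int8 weight, float32 scale)."""
    if per_channel:
        # One scale per output channel (row): shape (out_features, 1).
        max_abs = weight.abs().amax(dim=1, keepdim=True)
    else:
        # A single scale for the whole tensor: shape (1,).
        max_abs = weight.abs().amax().reshape(1)
    scale = (max_abs / 127.0).to(torch.float32)
    q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(64, 128)
q_t, s_t = quantize_int8(w, per_channel=False)  # s_t.shape == (1,)
q_c, s_c = quantize_int8(w, per_channel=True)   # s_c.shape == (64, 1)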
@@ -85,8 +85,11 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:
         is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
-        is_tensor = (weight_quant.strategy == input_quant.strategy ==
-                     QuantizationStrategy.TENSOR.value)
+        weight_strategy = (
+            weight_quant.strategy == QuantizationStrategy.TENSOR.value
+            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
+        is_tensor = (weight_strategy and input_quant.strategy
+                     == QuantizationStrategy.TENSOR.value)
         is_symmetric = weight_quant.symmetric and input_quant.symmetric
         is_static = not weight_quant.dynamic and not input_quant.dynamic

@@ -131,7 +134,8 @@ class CompressedTensorsConfig(QuantizationConfig):

         if self.quant_format == CompressionFormat.int_quantized.value:
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
-                return CompressedTensorsW8A8StaticTensor()
+                return CompressedTensorsW8A8StaticTensor(
+                    strategy=weight_quant.strategy)

             if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8DynamicToken(
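The relaxed check now accepts a channel-wise weight strategy so long as the activation strategy remains per-tensor. A standalone illustration of that predicate (using SimpleNamespace stand-ins for the real pydantic models, which is an assumption for brevity):

from types import SimpleNamespace

TENSOR, CHANNEL = "tensor", "channel"

def is_static_w8a8(weight_quant, input_quant) -> bool:
    # Mirrors the spirit of the updated _is_static_tensor_w8a8 check.
    is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
    weight_strategy = weight_quant.strategy in (TENSOR, CHANNEL)
    is_tensor = weight_strategy and input_quant.strategy == TENSOR
    is_symmetric = weight_quant.symmetric and input_quant.symmetric
    is_static = not weight_quant.dynamic and not input_quant.dynamic
    return is_8_bits and is_tensor and is_symmetric and is_static

q = lambda strategy: SimpleNamespace(
    num_bits=8, strategy=strategy, symmetric=True, dynamic=False)

assert is_static_w8a8(q(CHANNEL), q(TENSOR))      # now accepted
assert not is_static_w8a8(q(TENSOR), q(CHANNEL))  # per-channel activations: rejected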
@@ -0,0 +1,84 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class CompressedTensorsW8A8(CompressedTensorsScheme):
+
+    def __init__(self, strategy: str):
+        self.strategy = strategy
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(output_partition_sizes) if (
+            is_tensor_partitioned
+            or self.strategy == QuantizationStrategy.CHANNEL) else 1
+
+        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            shape = (weight_scale_dim, 1)
+
+        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "weight_loader": weight_loader,
+                "logical_widths": output_partition_sizes
+            })
+
+        # Don't need a shard_splitter for channel-wise quantization
+        # Use the default loading method
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            set_weight_attrs(weight_scale, {
+                "output_dim": 0,
+            })
+        else:
+            set_weight_attrs(
+                weight_scale, {
+                    "logical_widths": output_partition_sizes,
+                    "shard_splitter": self.scales_shard_splitter,
+                })
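To see what scales_shard_splitter in the new base class does for a fused QKV projection, here is a standalone sketch of the offset arithmetic (the widths and scale value are illustrative assumptions; logical_widths comes from the fused layer's output partition sizes):

import torch

logical_widths = [4096, 1024, 1024]       # q, k, v output widths (example)
param = torch.empty(sum(logical_widths))  # fused per-output scale buffer
qkv_idxs = {"q": 0, "k": 1, "v": 2}

shard_id = qkv_idxs["k"]
offset = sum(logical_widths[:shard_id])   # 4096
size = logical_widths[shard_id]           # 1024

loaded_scale = torch.tensor([0.02])       # single per-tensor scale from checkpoint
# Repeat the scalar so it broadcasts across the k slice of the fused buffer.
param[offset:offset + size] = loaded_scale.repeat(size)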
@@ -1,42 +1,15 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List

 import torch
-from torch.nn import Parameter

 from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationStrategy)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
+    CompressedTensorsW8A8)

 __all__ = ["CompressedTensorsW8A8DynamicToken"]


-class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsW8A8):

-    def __init__(self, strategy: str):
-        self.strategy = strategy
-
-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
-
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -44,54 +17,12 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        # When the scales have a single value, it is required that they be
-        # on the CPU for performance and CUDA Graphs compatibility. Please
-        # refer to the comment in
-        # CompressedTensorsW8A8StaticTensor::create_weights for further
-        # information.
-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        # when doing channel-wise quantization, number of scales
-        # is equal to output_dim
-        weight_scale_dim = sum(output_partition_sizes) if (
-            is_tensor_partitioned
-            or self.strategy == QuantizationStrategy.CHANNEL) else 1
-
-        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            shape = (weight_scale_dim, 1)
-
-        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(
-            weight, {
-                "input_dim": 1,
-                "output_dim": 0,
-                "weight_loader": weight_loader,
-                "logical_widths": output_partition_sizes
-            })
-
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
-
-        # Don't need a shard_splitter for channel-wise quantization
-        # Use the default loading method
-        if self.strategy == QuantizationStrategy.CHANNEL:
-            set_weight_attrs(weight_scale, {
-                "output_dim": 0,
-            })
-        else:
-            set_weight_attrs(
-                weight_scale, {
-                    "logical_widths": output_partition_sizes,
-                    "shard_splitter": self.scales_shard_splitter,
-                })
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
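For contrast with the static scheme, a dynamic-token scheme derives activation scales at runtime, one per token row, rather than loading a precomputed input_scale. A rough sketch of that idea in plain PyTorch (an assumption for illustration; vLLM's actual implementation uses fused custom ops):

import torch

def dynamic_per_token_quant(x: torch.Tensor):
    """Quantize activations to int8 with one scale per token (row)."""
    # x: (num_tokens, hidden_size); scale: (num_tokens, 1).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale.to(torch.float32)

x = torch.randn(8, 4096)
x_q, s = dynamic_per_token_quant(x)  # s.shape == (8, 1)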
@@ -1,37 +1,17 @@
-from typing import Callable, List, Tuple, Union
+from typing import Callable, List

 import torch
 from torch.nn import Parameter

 from vllm import _custom_ops as custom_ops
-from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import (  # noqa: E501
+    CompressedTensorsW8A8)
 from vllm.model_executor.utils import set_weight_attrs

 __all__ = ["CompressedTensorsW8A8StaticTensor"]


-class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
+class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):

-    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
-        if isinstance(shard_id, int):
-            return shard_id
-
-        assert isinstance(shard_id, str)
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-        assert shard_id in qkv_idxs
-        return qkv_idxs[shard_id]
-
-    def scales_shard_splitter(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int],
-            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        shard_id = self._shard_id_as_int(shard_id)
-        offset = sum(logical_widths[:shard_id])
-        size = logical_widths[shard_id]
-        # update loaded weight with copies for broadcast.
-        loaded_weight = loaded_weight.repeat(size)
-        return param[offset:offset + size], loaded_weight
-
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
@@ -39,41 +19,21 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):

-        is_tensor_partitioned = len(output_partition_sizes) != 1
-        weight_scale_dim = sum(
-            output_partition_sizes) if is_tensor_partitioned else 1
+        super().create_weights(
+            layer=layer,
+            output_partition_sizes=output_partition_sizes,
+            input_size_per_partition=input_size_per_partition,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader)

         input_scale = Parameter(torch.empty(1, dtype=torch.float32),
                                 requires_grad=False)

-        weight_scale = Parameter(torch.empty(weight_scale_dim,
-                                             dtype=torch.float32),
-                                 requires_grad=False)
-
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=torch.int8),
-                           requires_grad=False)
-
-        layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            "weight_loader": weight_loader,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
         layer.register_parameter("input_scale", input_scale)
         set_weight_attrs(input_scale, {
             "weight_loader": weight_loader,
             "ignore_warning": True,
         })
-        layer.register_parameter("weight_scale", weight_scale)
-        set_weight_attrs(
-            weight_scale, {
-                "weight_loader": weight_loader,
-                "shard_splitter": self.scales_shard_splitter,
-                "logical_widths": output_partition_sizes,
-                "ignore_warning": True,
-            })

     def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
         weight = layer.weight
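The overall shape of the refactor, reduced to a skeleton (hypothetical class names, illustration only; the real schemes also implement int8 GEMM calls in apply_weights): the shared base owns the int8 weight and the strategy-dependent weight_scale, while subclasses delegate via super() and register only what is specific to them, such as the static scheme's precomputed activation scale.

import torch
from torch.nn import Parameter

class BaseW8A8Scheme:
    def __init__(self, strategy: str):
        self.strategy = strategy

    def create_weights(self, layer: torch.nn.Module, **kwargs):
        # Shared setup: int8 weight + per-tensor or per-channel weight_scale.
        ...

class StaticScheme(BaseW8A8Scheme):
    def create_weights(self, layer: torch.nn.Module, **kwargs):
        super().create_weights(layer, **kwargs)  # reuse shared weight setup
        # Static-only addition: a single precomputed activation scale.
        layer.register_parameter(
            "input_scale",
            Parameter(torch.empty(1, dtype=torch.float32),
                      requires_grad=False))

layer = torch.nn.Linear(4, 4)
StaticScheme("tensor").create_weights(layer)  # layer now has input_scale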