[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (#5542)
This commit is contained in:
parent
7879f24dcc
commit
95db455e7f
@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
|||||||
|
|
||||||
|
|
||||||
def test_compressed_tensors_w8a8_static_setup(vllm_runner):
|
def test_compressed_tensors_w8a8_static_setup(vllm_runner):
|
||||||
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
|
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
|
||||||
with vllm_runner(model_path, enforce_eager=True) as llm:
|
with vllm_runner(model_path, enforce_eager=True) as llm:
|
||||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||||
layer = model.model.layers[0]
|
layer = model.model.layers[0]
|
||||||
@ -43,15 +43,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
|
|||||||
|
|
||||||
|
|
||||||
def test_compressed_tensors_no_enforce_eager(vllm_runner):
|
def test_compressed_tensors_no_enforce_eager(vllm_runner):
|
||||||
model_path = "nm-testing/tinyllama-oneshot-w8a8-static-v2"
|
model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
|
||||||
with vllm_runner(model_path) as llm:
|
with vllm_runner(model_path) as llm:
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
output = llm.generate("Hello world!", sampling_params=sampling_params)
|
output = llm.generate("Hello world!", sampling_params=sampling_params)
|
||||||
assert output
|
assert output
|
||||||
|
|
||||||
|
|
||||||
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
|
@pytest.mark.parametrize("model_args", [
|
||||||
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
|
("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
|
||||||
|
("nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2", "channel"),
|
||||||
|
])
|
||||||
|
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
|
||||||
|
model_path, strategy = model_args
|
||||||
with vllm_runner(model_path, dtype=torch.float16) as llm:
|
with vllm_runner(model_path, dtype=torch.float16) as llm:
|
||||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||||
layer = model.model.layers[0]
|
layer = model.model.layers[0]
|
||||||
@ -60,6 +64,7 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
|
|||||||
|
|
||||||
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
|
||||||
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
|
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
|
||||||
|
assert qkv_proj.scheme.strategy == strategy
|
||||||
assert qkv_proj.weight.dtype is torch.int8
|
assert qkv_proj.weight.dtype is torch.int8
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -468,13 +468,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
|||||||
"MergedColumnParallelLinear, assume the weight is "
|
"MergedColumnParallelLinear, assume the weight is "
|
||||||
"the same for all partitions.")
|
"the same for all partitions.")
|
||||||
|
|
||||||
if fp8_scales_shard_indexer is None:
|
|
||||||
if len(param_data.shape) == 0:
|
|
||||||
param_data = param_data.reshape(1)
|
|
||||||
|
|
||||||
if len(loaded_weight.shape) == 0:
|
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
@ -686,12 +679,6 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
"QKVParallelLinear, assume the weight is the same "
|
"QKVParallelLinear, assume the weight is the same "
|
||||||
"for all partitions.")
|
"for all partitions.")
|
||||||
|
|
||||||
if len(param_data.shape) == 0:
|
|
||||||
param_data = param_data.reshape(1)
|
|
||||||
|
|
||||||
if len(loaded_weight.shape) == 0:
|
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
assert param_data.shape == loaded_weight.shape
|
||||||
param_data.copy_(loaded_weight)
|
param_data.copy_(loaded_weight)
|
||||||
|
|
||||||
|
|||||||
@ -95,14 +95,15 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
|
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
|
||||||
input_quant: BaseModel) -> bool:
|
input_quant: BaseModel) -> bool:
|
||||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||||
is_token_tensor = (weight_quant.strategy
|
weight_strategy = (
|
||||||
== QuantizationStrategy.TENSOR.value) and (
|
weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||||
input_quant.strategy
|
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
|
||||||
|
is_token = (weight_strategy and input_quant.strategy
|
||||||
== QuantizationStrategy.TOKEN.value)
|
== QuantizationStrategy.TOKEN.value)
|
||||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||||
|
|
||||||
return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
|
return is_8_bits and is_token and is_symmetric and is_dynamic
|
||||||
|
|
||||||
def _is_w4a16(self, weight_quant: BaseModel,
|
def _is_w4a16(self, weight_quant: BaseModel,
|
||||||
input_quant: BaseModel) -> bool:
|
input_quant: BaseModel) -> bool:
|
||||||
@ -133,7 +134,8 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
return CompressedTensorsW8A8StaticTensor()
|
return CompressedTensorsW8A8StaticTensor()
|
||||||
|
|
||||||
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
|
||||||
return CompressedTensorsW8A8DynamicToken()
|
return CompressedTensorsW8A8DynamicToken(
|
||||||
|
strategy=weight_quant.strategy)
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"No compressed-tensors compatible scheme was found.")
|
"No compressed-tensors compatible scheme was found.")
|
||||||
|
|||||||
@ -6,6 +6,8 @@ from torch.nn import Parameter
|
|||||||
from vllm import _custom_ops as custom_ops
|
from vllm import _custom_ops as custom_ops
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme)
|
CompressedTensorsScheme)
|
||||||
|
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||||
|
QuantizationStrategy)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
__all__ = ["CompressedTensorsW8A8DynamicToken"]
|
__all__ = ["CompressedTensorsW8A8DynamicToken"]
|
||||||
@ -13,6 +15,9 @@ __all__ = ["CompressedTensorsW8A8DynamicToken"]
|
|||||||
|
|
||||||
class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
|
class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
|
||||||
|
|
||||||
|
def __init__(self, strategy: str):
|
||||||
|
self.strategy = strategy
|
||||||
|
|
||||||
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
|
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
|
||||||
if isinstance(shard_id, int):
|
if isinstance(shard_id, int):
|
||||||
return shard_id
|
return shard_id
|
||||||
@ -45,11 +50,17 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
|
|||||||
# CompressedTensorsW8A8StaticTensor::create_weights for further
|
# CompressedTensorsW8A8StaticTensor::create_weights for further
|
||||||
# information.
|
# information.
|
||||||
is_tensor_partitioned = len(output_partition_sizes) != 1
|
is_tensor_partitioned = len(output_partition_sizes) != 1
|
||||||
weight_scale_dim = sum(
|
# when doing channel-wise quantization, number of scales
|
||||||
output_partition_sizes) if is_tensor_partitioned else 1
|
# is equal to output_dim
|
||||||
|
weight_scale_dim = sum(output_partition_sizes) if (
|
||||||
|
is_tensor_partitioned
|
||||||
|
or self.strategy == QuantizationStrategy.CHANNEL) else 1
|
||||||
|
|
||||||
weight_scale = Parameter(torch.empty(weight_scale_dim,
|
shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
|
||||||
dtype=torch.float32),
|
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
shape = (weight_scale_dim, 1)
|
||||||
|
|
||||||
|
weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
|
|
||||||
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
weight = Parameter(torch.empty(sum(output_partition_sizes),
|
||||||
@ -67,11 +78,19 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
|
|||||||
})
|
})
|
||||||
|
|
||||||
layer.register_parameter("weight_scale", weight_scale)
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
|
||||||
|
|
||||||
|
# Don't need a shard_splitter for channel-wise quantization
|
||||||
|
# Use the default loading method
|
||||||
|
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||||
|
set_weight_attrs(weight_scale, {
|
||||||
|
"output_dim": 0,
|
||||||
|
})
|
||||||
|
else:
|
||||||
set_weight_attrs(
|
set_weight_attrs(
|
||||||
weight_scale, {
|
weight_scale, {
|
||||||
"weight_loader": weight_loader,
|
"logical_widths": output_partition_sizes,
|
||||||
"shard_splitter": self.scales_shard_splitter,
|
"shard_splitter": self.scales_shard_splitter,
|
||||||
"logical_widths": output_partition_sizes
|
|
||||||
})
|
})
|
||||||
|
|
||||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user