[misc] update tpu int8 to use new vLLM Parameters (#7973)

parent d78789ac16
commit 86a677de42
vllm/model_executor/layers/linear.py
@@ -23,7 +23,8 @@ logger = init_logger(__name__)

 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod"
 ]
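Note on this hunk: WEIGHT_LOADER_V2_SUPPORTED is the allowlist of quantization linear methods that take the new vLLM Parameters loading path, so adding "TPUInt8LinearMethod" opts TPU int8 into it. A minimal sketch of how such a name-based gate is typically consulted; the variable names here are illustrative, not vLLM's exact dispatch code:

# Sketch, assuming a name-based gate: look up the quant method's class name
# and pick the v2 loader only for methods that have been migrated.
quant_method_name = type(self.quant_method).__name__
if quant_method_name in WEIGHT_LOADER_V2_SUPPORTED:
    loader = self.weight_loader_v2  # new vLLM Parameters path
else:
    loader = self.weight_loader     # legacy set_weight_attrs path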
vllm/model_executor/layers/quantization/tpu_int8.py
@@ -7,7 +7,7 @@ from torch.nn.parameter import Parameter

 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import ModelWeightParameter

 ACTIVATION_SCHEMES = ["none"]
@@ -64,16 +64,16 @@ class TPUInt8LinearMethod(LinearMethodBase):
                        output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=params_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })

     def _quantize_weight(
             self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
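Note on this hunk: the plain Parameter plus a trailing set_weight_attrs call is replaced by a single ModelWeightParameter (imported above from vllm.model_executor.parameter) that carries the sharding dims and loader callback from construction. A self-contained toy illustration of the pattern; ToyModelWeightParameter is hypothetical, not vLLM's implementation:

# Hypothetical stand-in showing what the real ModelWeightParameter bundles:
# the tensor plus the metadata that set_weight_attrs used to attach afterwards.
import torch
from torch.nn.parameter import Parameter

class ToyModelWeightParameter(Parameter):

    def __new__(cls, data, input_dim, output_dim, weight_loader):
        param = super().__new__(cls, data, requires_grad=False)
        param.input_dim = input_dim          # dim sharded across TP ranks (input)
        param.output_dim = output_dim        # dim sharded across TP ranks (output)
        param.weight_loader = weight_loader  # callback invoked per checkpoint shard
        return param

Bundling the metadata into the parameter type is what lets the v2 loading path rely on it being present, instead of probing for ad hoc attributes set after registration.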
@@ -92,6 +92,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
         return qweight, qscale

     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
         device = layer.weight.device
         qweight, qscale = self._quantize_weight(layer.weight)
         qweight = qweight.to(device)
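Note on this hunk: the added first line appears to swap the loaded ModelWeightParameter back to a plain torch Parameter once loading is finished, so quantization operates on an ordinary tensor without loader metadata. For reference, a per-channel symmetric int8 quantizer matching _quantize_weight's (weight) -> (qweight, qscale) shape might look like the sketch below; vLLM's actual rounding and scale layout may differ:

# Sketch only: symmetric int8 quantization with one scale per output row.
import torch

def quantize_int8_sketch(weight: torch.Tensor):
    # clamp avoids division by zero for all-zero rows
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    qweight = torch.round(weight / scale).clamp(-128, 127).to(torch.int8)
    return qweight, scale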