[misc] update tpu int8 to use new vLLM Parameters (#7973)
This commit is contained in:
parent d78789ac16
commit 86a677de42
@@ -23,7 +23,8 @@ logger = init_logger(__name__)
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
     "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
-    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod"
+    "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
+    "TPUInt8LinearMethod"
 ]
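The hunk above registers "TPUInt8LinearMethod" in the list that gates the v2 weight-loading path. A minimal sketch of how membership in such a registry is typically checked (hypothetical helper name, not vLLM's exact loader code):

from vllm.model_executor.layers.linear import WEIGHT_LOADER_V2_SUPPORTED

def uses_weight_loader_v2(quant_method) -> bool:
    # Opt-in is keyed by the quantization method's class name: appearing
    # in WEIGHT_LOADER_V2_SUPPORTED routes the layer to the v2 loader.
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED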
@@ -7,7 +7,7 @@ from torch.nn.parameter import Parameter
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import ModelWeightParameter
 
 ACTIVATION_SCHEMES = ["none"]
 
@@ -64,16 +64,16 @@ class TPUInt8LinearMethod(LinearMethodBase):
                        output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight = ModelWeightParameter(data=torch.empty(
+            sum(output_partition_sizes),
+            input_size_per_partition,
+            dtype=params_dtype),
+                                      input_dim=1,
+                                      output_dim=0,
+                                      weight_loader=weight_loader)
         layer.register_parameter("weight", weight)
-        set_weight_attrs(weight, {
-            **extra_weight_attrs,
-            "input_dim": 1,
-            "output_dim": 0,
-        })
 
     def _quantize_weight(
             self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
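The rewrite above moves sharding metadata and the loader callback onto the parameter itself: ModelWeightParameter takes input_dim, output_dim, and weight_loader as constructor arguments, replacing the old Parameter-plus-set_weight_attrs pattern. A simplified stand-in illustrating the idea (hypothetical class; the real one lives in vllm.model_executor.parameter and does more):

import torch
from torch.nn.parameter import Parameter

class ModelWeightParameterSketch(Parameter):
    """Hypothetical stand-in for vLLM's ModelWeightParameter."""

    def __new__(cls, data: torch.Tensor, input_dim: int, output_dim: int,
                weight_loader):
        param = super().__new__(cls, data, requires_grad=False)
        param.input_dim = input_dim          # dim sharded over the input size
        param.output_dim = output_dim        # dim sharded over output partitions
        param.weight_loader = weight_loader  # callback run per checkpoint shard
        return param

Bundling these attributes at construction time lets tensor-parallel loaders read param.input_dim / param.output_dim directly instead of relying on attributes patched on after registration.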
@@ -92,6 +92,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
         return qweight, qscale
 
     def process_weights_after_loading(self, layer: Module) -> None:
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
         device = layer.weight.device
         qweight, qscale = self._quantize_weight(layer.weight)
         qweight = qweight.to(device)
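Since the weight now arrives as a ModelWeightParameter, the added line rewraps it as a plain torch Parameter before quantization runs. For context, a self-contained sketch of symmetric per-output-channel int8 quantization in the spirit of _quantize_weight (an approximation, assuming absmax scaling; the exact tpu_int8 implementation may differ):

from typing import Tuple

import torch

def quantize_int8_per_channel(
        weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # Symmetric absmax scale per output row; the clamp avoids a
    # divide-by-zero on all-zero rows.
    max_val = weight.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
    qscale = max_val / 127
    qweight = torch.clamp(torch.round(weight / qscale), -128,
                          127).to(torch.int8)
    return qweight, qscale.squeeze(-1).to(weight.dtype)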