[Bugfix] Avoid Warnings in SparseML Activation Quantization (#5120)

This commit is contained in:
Robert Shaw 2024-05-30 17:04:37 -07:00 committed by GitHub
parent 45a1a69b98
commit b35be5403f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -89,23 +89,34 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, {"weight_loader": weight_loader})
set_weight_attrs(weight, {
"weight_loader": weight_loader,
"input_dim": 1,
"output_dim": 0,
})
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {"weight_loader": weight_loader})
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("input_zero_point", input_zero_point)
set_weight_attrs(input_zero_point, {"weight_loader": weight_loader})
set_weight_attrs(input_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes
"logical_widths": output_partition_sizes,
"ignore_warning": True,
})
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
set_weight_attrs(weight_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True
})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight