[Bugfix] Avoid Warnings in SparseML Activation Quantization (#5120)
This commit is contained in:
parent
45a1a69b98
commit
b35be5403f
@ -89,23 +89,34 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
|
||||
requires_grad=False)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
|
||||
set_weight_attrs(weight, {"weight_loader": weight_loader})
|
||||
|
||||
set_weight_attrs(weight, {
|
||||
"weight_loader": weight_loader,
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
set_weight_attrs(input_scale, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(input_scale, {
|
||||
"weight_loader": weight_loader,
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
set_weight_attrs(input_zero_point, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(input_zero_point, {
|
||||
"weight_loader": weight_loader,
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(
|
||||
weight_scale, {
|
||||
"weight_loader": weight_loader,
|
||||
"shard_splitter": self.scales_shard_splitter,
|
||||
"logical_widths": output_partition_sizes
|
||||
"logical_widths": output_partition_sizes,
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(weight_zero_point, {
|
||||
"weight_loader": weight_loader,
|
||||
"ignore_warning": True
|
||||
})
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
||||
weight = layer.weight
|
||||
|
||||
Loading…
Reference in New Issue
Block a user