[Bugfix] Fix empty (nullptr) channelwise scales when loading wNa16 using compressed tensors (#6798)

Lucas Wilkinson 2024-07-25 18:05:09 -04:00 committed by GitHub
parent 6a1e25b151
commit cd7edc4e87


@@ -55,7 +55,12 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         output_size_per_partition = sum(output_partition_sizes)
 
         # If group_size is -1, we are in channelwise case.
-        group_size = input_size if self.group_size == -1 else self.group_size
+        channelwise = (self.group_size == -1)
+        group_size = input_size if channelwise else self.group_size
+        row_parallel = (input_size != input_size_per_partition)
+        # In the case of channelwise quantization, we need to replicate the
+        # scales across all gpus.
+        partition_scales = (row_parallel and not channelwise)
 
         verify_marlin_supports_shape(
             output_size_per_partition=output_size_per_partition,
@@ -66,8 +71,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
         weight_scale_dim = None
         scales_and_zp_size = input_size // group_size
 
-        if (input_size != input_size_per_partition
-                and self.group_size is not None):
+        if partition_scales:
             assert input_size_per_partition % group_size == 0
             weight_scale_dim = 1
             scales_and_zp_size = input_size_per_partition // group_size
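
The diff can be read as follows: before this change, scales were partitioned whenever the layer was row-parallel and self.group_size was not None. Because the channelwise case is encoded as group_size == -1 (which is not None), channelwise scales were also sharded, and input_size_per_partition // group_size (with group_size == input_size) rounded down to 0, which left the scale tensor empty as described in the title. Below is a minimal, standalone sketch of the fixed decision logic; compute_scale_layout and the example sizes are illustrative only (not part of vLLM) and mirror just the lines touched by this diff.

# Minimal sketch of the scale-partitioning decision from the diff above.
# compute_scale_layout is a hypothetical helper, not vLLM code.

def compute_scale_layout(input_size: int,
                         input_size_per_partition: int,
                         config_group_size: int):
    """Return (group_size, partition_scales, scales_and_zp_size)."""
    # group_size == -1 means one scale per output channel (channelwise).
    channelwise = (config_group_size == -1)
    group_size = input_size if channelwise else config_group_size

    # The layer is row-parallel when its input dimension is sharded.
    row_parallel = (input_size != input_size_per_partition)

    # Channelwise scales span the whole (unsharded) input dimension, so they
    # are replicated on every rank; only group-wise scales of a row-parallel
    # layer are partitioned.
    partition_scales = (row_parallel and not channelwise)

    if partition_scales:
        assert input_size_per_partition % group_size == 0
        scales_and_zp_size = input_size_per_partition // group_size
    else:
        scales_and_zp_size = input_size // group_size
    return group_size, partition_scales, scales_and_zp_size


# Channelwise weights on a row-parallel layer with tensor parallelism of 2
# (input 4096 sharded to 2048): scales are replicated, one entry per channel.
print(compute_scale_layout(4096, 2048, -1))   # (4096, False, 1)
# Group size 128 on the same layer: scales are sharded along the input dim.
print(compute_scale_layout(4096, 2048, 128))  # (128, True, 16)

In the first example, the pre-fix condition would have taken the partitioned branch and computed 2048 // 4096 == 0 scale entries; with the fix, partition_scales is False and each rank keeps the full per-channel scale.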