[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (#5921)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Robert Shaw 2024-06-28 14:43:49 -04:00 committed by GitHub
parent b185230744
commit 2cd402e169
2 changed files with 31 additions and 22 deletions

vllm/model_executor/layers/linear.py

@@ -383,8 +383,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                                                    None)

         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -567,8 +572,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                                                    None)

         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
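Note (not part of the diff): a minimal standalone sketch of the fused-on-disk branch added above, restated as a free function so the control flow is easier to see. The name load_fused_weight is hypothetical; in vLLM this logic lives inside the weight_loader methods shown in the hunks.

from typing import Callable, Optional

import torch


def load_fused_weight(param_data: torch.Tensor,
                      loaded_weight: torch.Tensor,
                      fp8_scales_shard_indexer: Optional[Callable] = None) -> None:
    # Fused checkpoint: the whole qkv/mlp tensor is copied in one shot, but an
    # fp8 scale still has to be routed through the indexer; shard_id=None tells
    # the indexer that the module was fused on disk.
    if fp8_scales_shard_indexer is not None:
        param_data, loaded_weight = fp8_scales_shard_indexer(
            param_data, loaded_weight, None)
    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)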

vllm/model_executor/layers/quantization/fp8.py

@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """

     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()

@@ -111,6 +112,7 @@ class Fp8LinearMethod(LinearMethodBase):
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
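Why pre-fill the scale buffer with the fp8 minimum? A plausible reading (illustrative sketch, not part of the diff): with a fused checkpoint only one entry of the per-shard scale buffer is actually loaded, so initializing the rest to torch.finfo(torch.float8_e4m3fn).min (a large negative value, about -448 for e4m3fn) keeps the later weight_scale.max() equal to the single loaded scale.

import torch

# Illustrative only: three logical shards, but a fused checkpoint supplies one scale.
scale = torch.empty(3, dtype=torch.float32)
scale[:] = torch.finfo(torch.float8_e4m3fn).min  # ~ -448.0 for e4m3fn
scale[0] = 0.5                                   # the single scale loaded from disk
assert scale.max().item() == 0.5                 # max() recovers the loaded value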
@@ -169,11 +171,15 @@ class Fp8LinearMethod(LinearMethodBase):
             **extra_weight_attrs)

     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+        self, param: torch.Tensor, loaded_weight: torch.Tensor,
+        shard_id: Optional[Union[str,
+                                 int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}

-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
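A standalone sketch of the indexing behavior after this change (simplified from the class method; error handling for unknown ids omitted, function name hypothetical): "q"/"k"/"v" string ids map to row indices, integer ids pass through, and None, which the fused-checkpoint path now passes, is treated as index 0 while flagging that the module was fused in the checkpoint.

from typing import Optional, Tuple, Union


def _shard_index(shard_id: Optional[Union[str, int]]) -> Tuple[int, bool]:
    """Sketch only: return (row index, fused_module_in_checkpoint)."""
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if shard_id is None:            # fused qkv/mlp checkpoint
        return 0, True
    if isinstance(shard_id, int):   # merged MLP shards are passed as ints
        return shard_id, False
    return qkv_idxs[shard_id], False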
@@ -205,15 +211,17 @@ class Fp8LinearMethod(LinearMethodBase):
             # WEIGHT_SCALE / WEIGHT
             #   Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end

             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

             # WEIGHT
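The requantization loop is now skipped for fused checkpoints: when shards are stored separately on disk, each logical weight arrives quantized against its own scale and must be dequantized and requantized to the shared max scale, while a fused weight already shares a single scale. For reference, a hedged sketch of the two per-tensor helpers the loop relies on (per_tensor_quantize is visible, truncated, at the end of this diff; per_tensor_dequantize is assumed to be the symmetric inverse):

from typing import Union

import torch


def per_tensor_quantize(tensor: torch.Tensor,
                        inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    # Scale into the fp8 e4m3 range, clamp, and cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(tensor: torch.Tensor,
                          inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
    # Assumed inverse: upcast the fp8 weight and multiply the scale back in.
    return tensor.to(torch.float16) * inv_scale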
@@ -227,10 +235,6 @@ class Fp8LinearMethod(LinearMethodBase):
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
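Dropping the all_close_1d guard follows from the same change: assuming input_scale is created through the same min-initialized scale buffer shown above, a fused checkpoint populates only one entry while the others keep their initialization, so the scales are no longer expected to be equal and collapsing them with .max() picks the loaded value. Illustrative only:

import torch
from torch.nn import Parameter

# Illustrative: one loaded activation scale, two entries left at the init value.
fp8_min = torch.finfo(torch.float8_e4m3fn).min
input_scale = torch.tensor([0.5, fp8_min, fp8_min], dtype=torch.float32)
input_scale = Parameter(input_scale.max(), requires_grad=False)
assert input_scale.item() == 0.5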
@@ -317,11 +321,6 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
             del layer.kv_scale


-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)