[ Bugfix ] Enabling Loading Models With Fused QKV/MLP on Disk with FP8 (#5921)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
commit 2cd402e169 (parent b185230744)
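Some checkpoints store the q/k/v and gate/up projections pre-concatenated ("fused") as a single tensor per module, with a single fp8 weight scale, so weight_loader is invoked once with loaded_shard_id=None instead of once per logical shard. Before this fix, the loader had no way to route that single scale into the per-shard scale buffer. A sketch of what such a checkpoint looks like (key names and shapes are illustrative, not taken from this PR):

    import torch

    # Illustrative only: a fused checkpoint holds one tensor and one fp8
    # scale per merged projection, rather than separate gate/up entries.
    state_dict = {
        "layers.0.mlp.gate_up_proj.weight":
            torch.empty(2 * 11008, 4096, dtype=torch.float8_e4m3fn),
        "layers.0.mlp.gate_up_proj.weight_scale": torch.tensor(0.02),
    }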
@@ -383,8 +383,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                               None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
@@ -567,8 +572,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                               None)
 
         if loaded_shard_id is None:
-            # Loaded weight is already packed.
+            # Loaded weight is already fused on disk (qkv/mlp).
             if output_dim is None:
+                # If fp8 + scale, need to send to each shard.
+                if fp8_scales_shard_indexer is not None:
+                    param_data, loaded_weight = fp8_scales_shard_indexer(
+                        param_data, loaded_weight, loaded_shard_id)
+
                 assert param_data.shape == loaded_weight.shape
                 param_data.copy_(loaded_weight)
                 return
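Both MergedColumnParallelLinear and QKVParallelLinear receive the same fix: on the fused path, a parameter without an output_dim (i.e. a scale) is now sent through the fp8 indexer before the copy. A condensed, standalone sketch of the new branch (helper names here are mine, not vLLM's); the remaining hunks are in the Fp8LinearMethod quantization code:

    import torch

    # Condensed sketch of the added branch in both weight_loader methods:
    # scales (output_dim is None) are routed through the fp8 indexer with
    # shard_id=None so the indexer can pick the target slot.
    def load_fused(param_data, loaded_weight, output_dim, indexer):
        if output_dim is None and indexer is not None:
            param_data, loaded_weight = indexer(param_data, loaded_weight, None)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    scale_buf = torch.full((3,), torch.finfo(torch.float8_e4m3fn).min)
    load_fused(scale_buf, torch.tensor(0.5), None,
               lambda p, w, s: (p[0:1], w.reshape(1)))  # toy indexer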
@@ -98,6 +98,7 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
+        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -111,6 +112,7 @@ class Fp8LinearMethod(LinearMethodBase):
         scale = Parameter(torch.empty(len(output_partition_sizes),
                                       dtype=torch.float32),
                           requires_grad=False)
+        scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
         set_weight_attrs(
             scale, {
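Pre-filling the scale buffer with the fp8 finfo minimum matters for fused checkpoints: only one slot ever gets written, and the later max() reduction must still return the real scale rather than uninitialized memory. A quick demonstration:

    import torch

    fp8_min = torch.finfo(torch.float8_e4m3fn).min  # -448.0 for e4m3fn
    scale = torch.full((3,), fp8_min)               # slots for q, k, v
    scale[0] = 0.5            # fused checkpoint populates a single slot
    assert scale.max().item() == 0.5                # max() ignores the fill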
@@ -169,11 +171,15 @@ class Fp8LinearMethod(LinearMethodBase):
             **extra_weight_attrs)
 
     def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Optional[Union[str,
+                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
         qkv_idxs = {"q": 0, "k": 1, "v": 2}
 
-        if isinstance(shard_id, int):
+        if shard_id is None:
+            shard_id = 0
+            self.fused_module_in_checkpoint = True
+        elif isinstance(shard_id, int):
             pass
         elif isinstance(shard_id, str):
             if shard_id not in qkv_idxs:
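The indexer now treats shard_id=None as "the whole fused tensor": it records fused_module_in_checkpoint for later use and falls through with slot 0. A standalone sketch of the dispatch (not the vLLM method itself):

    from typing import Optional, Union

    def resolve_shard_id(shard_id: Optional[Union[str, int]]) -> int:
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        if shard_id is None:           # fused checkpoint: single scale, slot 0
            return 0
        if isinstance(shard_id, int):  # merged-column shard index
            return shard_id
        return qkv_idxs[shard_id]      # "q" / "k" / "v"

    assert resolve_shard_id(None) == 0
    assert resolve_shard_id("v") == 2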
@@ -205,15 +211,17 @@ class Fp8LinearMethod(LinearMethodBase):
             # WEIGHT_SCALE / WEIGHT
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
-            start = 0
-            for idx, logical_width in enumerate(layer.logical_widths):
-                end = start + logical_width
-                weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
-                                                  layer.weight_scale[idx])
-
-                layer.weight[start:end, :] = per_tensor_quantize(
-                    weight_dq, layer.weight_scale.max())
-                start = end
+            if not self.fused_module_in_checkpoint:
+                start = 0
+                for idx, logical_width in enumerate(layer.logical_widths):
+                    end = start + logical_width
+                    weight_dq = per_tensor_dequantize(
+                        layer.weight[start:end, :], layer.weight_scale[idx])
+
+                    layer.weight[start:end, :] = per_tensor_quantize(
+                        weight_dq, layer.weight_scale.max())
+                    start = end
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
             # WEIGHT
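The requantize-to-a-single-scale loop is now skipped when the checkpoint was fused: the weights already share one scale, and the unused scale slots hold the finfo.min fill rather than real scales, so dequantizing per slot would corrupt the weights. A sketch of what the loop does; the two helpers below are my own minimal versions, assumed to behave like fp8.py's per_tensor_dequantize/per_tensor_quantize:

    import torch

    def per_tensor_dequantize(t_q, scale):
        # assumed: upcast to float32 and multiply by the scale
        return t_q.to(torch.float32) * scale

    def per_tensor_quantize(t, scale):
        # assumed: divide by the scale, clamp to fp8 range, downcast
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (t / scale).clamp(min=finfo.min,
                                 max=finfo.max).to(torch.float8_e4m3fn)

    scales = torch.tensor([0.01, 0.02, 0.04])  # one scale per logical weight
    w_q = torch.randn(3, 4).to(torch.float8_e4m3fn)
    for i in range(3):  # requantize every shard onto the single max scale
        w_q[i] = per_tensor_quantize(
            per_tensor_dequantize(w_q[i], scales[i]), scales.max())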
@@ -227,10 +235,6 @@ class Fp8LinearMethod(LinearMethodBase):
             if self.quant_config.activation_scheme == "dynamic":
                 layer.input_scale = None
             elif self.quant_config.activation_scheme == "static":
-                if not all_close_1d(layer.input_scale):
-                    raise ValueError(
-                        "All the input_scales for the logical weights of a "
-                        f"layer must be equal. But got {layer.input_scale}")
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
             else:
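Dropping the all_close_1d check follows from the same change: with a fused checkpoint only slot 0 of input_scale carries a real value while the remaining slots presumably keep the finfo.min fill, so the slots are intentionally unequal and max() picks out the real scale:

    import torch

    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    input_scale = torch.tensor([0.5, fp8_min, fp8_min])  # only slot 0 loaded
    assert input_scale.max().item() == 0.5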
@@ -317,11 +321,6 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
             del layer.kv_scale
 
 
-def all_close_1d(x: torch.Tensor) -> bool:
-    assert len(x.shape) == 1
-    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
-
-
 def per_tensor_quantize(tensor: torch.Tensor,
                         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)