[Bugfix] Fix Phi-3 BNB online quantization (#10417)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2024-11-19 11:21:42 +08:00 committed by GitHub
parent 284203f171
commit 7eb719df13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 3 deletions

View File

@ -470,7 +470,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp).
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(
@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
@ -495,7 +498,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim] // 2
shard_offset = shard_size * shard_id
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
@ -808,7 +813,8 @@ class QKVParallelLinear(ColumnParallelLinear):
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv/mlp).
# Loaded weight is already fused on disk (qkv).
# (e.g., Phi-3's qkv_proj).
if output_dim is None:
if needs_scalar_to_array:
param_data, loaded_weight = adjust_scalar_to_fused_array(

View File

@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM):
"gate_up_proj",
],
}
    # BitsAndBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_up_proj.",
".down_proj.",
".qkv_proj.",
".o_proj.",
]
# Initialize an empty dict when there is no stacked parameter mapping.
bitsandbytes_stacked_params_mapping = {}