[Bugfix]Fix Phi-3 BNB online quantization (#10417)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
284203f171
commit
7eb719df13
@ -470,7 +470,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
# Loaded weight is already fused on disk (mlp).
|
||||
# (e.g., Phi-3's gate_up_proj).
|
||||
if output_dim is None:
|
||||
if needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
@ -480,6 +481,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
current_shard_offset = 0
|
||||
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
|
||||
False)
|
||||
shard_offsets: List[Tuple[int, int, int]] = []
|
||||
for i, output_size in enumerate(self.output_sizes):
|
||||
shard_offsets.append((i, current_shard_offset, output_size))
|
||||
@ -495,7 +498,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
if use_bitsandbytes_4bit:
|
||||
shard_size = loaded_weight.shape[output_dim] // 2
|
||||
shard_offset = shard_size * shard_id
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
@ -808,7 +813,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
# Loaded weight is already fused on disk (qkv).
|
||||
# (e.g., Phi-3's qkv_proj).
|
||||
if output_dim is None:
|
||||
if needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
|
||||
@ -14,3 +14,13 @@ class Phi3ForCausalLM(LlamaForCausalLM):
|
||||
"gate_up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
# BitandBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".gate_up_proj.",
|
||||
".down_proj.",
|
||||
".qkv_proj.",
|
||||
".o_proj.",
|
||||
]
|
||||
# Initialize an empty dict when there is no stacked parameter mapping.
|
||||
bitsandbytes_stacked_params_mapping = {}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user