[Core] Refactor GGUF parameters packing and forwarding (#8859)
parent 4f95ffee6f
commit f19da64871
@@ -19,12 +19,12 @@ MAX_MODEL_LEN = 1024
 
 # FIXME: Move this to confest
 MODELS = [
-    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-     hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                     filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")),
-    ("TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-     hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF",
-                     filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")),
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
+                     filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
+    ("meta-llama/Llama-3.2-1B-Instruct",
+     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
+                     filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
     ("Qwen/Qwen2-1.5B-Instruct",
      hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
                      filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
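The updated fixtures swap the TinyLlama checkpoints for Llama-3.2-1B-Instruct GGUF files. A minimal usage sketch (not part of this diff; the prompt and sampling settings are illustrative) of how such a checkpoint is served, with the tokenizer taken from the original unquantized repo:

from huggingface_hub import hf_hub_download
from vllm import LLM, SamplingParams

# Download one of the quantized checkpoints referenced by the test above.
gguf_path = hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
                            filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")
# Point vLLM at the .gguf file; the tokenizer comes from the base HF repo.
llm = LLM(model=gguf_path,
          tokenizer="meta-llama/Llama-3.2-1B-Instruct",
          max_model_len=1024)
print(llm.generate(["Hello!"], SamplingParams(max_tokens=8)))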
@@ -440,17 +440,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
-
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
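Instead of eagerly materializing one flat quantized buffer from GGML_QUANT_SIZES, the loader now slices each incoming GGUF shard down to the current tensor-parallel rank and parks it in param.data_container until every shard has arrived. A toy sketch of the narrow step (not vLLM code; the shapes and rank are made up):

import torch

# Each TP rank keeps only its slice of the quantized tensor along output_dim.
tp_size, tp_rank, output_dim = 2, 0, 0
loaded_weight = torch.arange(8 * 18, dtype=torch.uint8).reshape(8, 18)
shard_size = loaded_weight.size(output_dim) // tp_size
start_idx = tp_rank * shard_size
shard = loaded_weight.narrow(output_dim, start_idx, shard_size)
assert shard.shape == (4, 18)  # half the output rows, stashed until packing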
@@ -515,18 +521,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
@@ -783,17 +777,23 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
-
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                 shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
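The QKV path follows the same protocol, only it waits for three shards (q, k, v) rather than the two gate/up shards of the merged MLP projection, and it records which container slot each shard landed in via shard_id_map. A toy sketch of that bookkeeping (not vLLM code):

# Append shards as they arrive and remember their slot; materialize when full.
data_container, shard_id_map = [], {}
for name, shard in (("q", "Wq"), ("k", "Wk"), ("v", "Wv")):
    shard_id_map[name] = len(data_container)
    data_container.append(shard)
    ready = len(data_container) == 3  # == 2 for merged gate/up projections
assert ready and shard_id_map == {"q": 0, "k": 1, "v": 2}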
@@ -883,18 +883,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
@@ -86,15 +86,16 @@ class GGUFLinearMethod(LinearMethodBase):
         output_size_per_partition = sum(output_partition_sizes)
 
         tensor_shape = (output_size_per_partition, input_size_per_partition)
-        qweight = UninitializedParameter(requires_grad=False)
+        qweight = GGUFUninitializedParameter(requires_grad=False)
         set_weight_attrs(
             qweight, {
                 "input_dim": 1,
                 "output_dim": 0,
                 "tensor_shape": tensor_shape,
                 "is_gguf_weight": True,
-                "shard_size": {},
+                "data_container": [],
                 "shard_id": [],
+                "shard_id_map": {},
             })
         set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("qweight", qweight)
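create_weights now registers a GGUFUninitializedParameter and hangs the per-shard bookkeeping (data_container, shard_id, shard_id_map) directly off it. A minimal sketch of what that amounts to, treating set_weight_attrs as essentially a setattr loop (illustrative, not the vLLM helper itself):

from torch.nn.parameter import UninitializedParameter

# Attach loader metadata to the still-shapeless parameter object so the weight
# loader can later find it as attributes of layer.qweight.
qweight = UninitializedParameter(requires_grad=False)
for key, value in {
        "is_gguf_weight": True,
        "data_container": [],
        "shard_id": [],
        "shard_id_map": {},
}.items():
    setattr(qweight, key, value)
assert qweight.shard_id_map == {}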
@@ -116,21 +117,17 @@ class GGUFLinearMethod(LinearMethodBase):
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        shard_size = getattr(layer.qweight, "shard_size", None)
         shard_id = getattr(layer.qweight, "shard_id", None)
 
-        if shard_id and shard_size:
-            result = []
-            offset = 0
+        if shard_id:
             # dequantize shard weights respectively
             shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
+            qweight = layer.qweight.unbind(0)
+            result = []
             for id in shard_id:
-                shard_weight = layer.qweight[
-                    offset:offset +
-                    shard_size[id][0], :shard_size[id][1]].contiguous()
+                q_idx = layer.qweight.shard_id_map[id]
                 qweight_type = layer.qweight_type.shard_weight_type[id]
-                result.append(_fuse_mul_mat(x, shard_weight, qweight_type))
-                offset += shard_size[id][0]
+                result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
             out = torch.cat(result, axis=1)
         else:
             qweight = layer.qweight
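In the forward pass the packed nested tensor is unbound back into its per-shard quantized weights, each shard is multiplied with the input separately (via _fuse_mul_mat in the real code), and the partial outputs are concatenated. A toy sketch of that control flow with dense stand-in weights (not vLLM code):

import torch

# Pack three stand-in shards, recover them with unbind(0), multiply each with
# the activations, then concatenate along the output dimension.
shards = [torch.randn(6, 4), torch.randn(6, 4), torch.randn(6, 4)]
packed = torch.nested.nested_tensor(shards)
x = torch.randn(2, 4)
outs = [x @ w.t() for w in packed.unbind(0)]
out = torch.cat(outs, dim=1)
assert out.shape == (2, 18)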
@@ -162,3 +159,20 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
         dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
                                       x_flat.shape[0])
         return dequant.view(*x.shape, hidden_size)
+
+
+class GGUFUninitializedParameter(UninitializedParameter):
+    cls_to_become = Parameter
+    data_container: List[torch.Tensor]
+
+    def materialize_nested(self) -> Parameter:
+        nested_data = torch.nested.nested_tensor(self.data_container,
+                                                 device=self.device,
+                                                 dtype=torch.uint8)
+        self.data_container.clear()
+        param = torch.Tensor._make_subclass(self.cls_to_become,
+                                            nested_data,
+                                            require_grad=False)
+        for k, v in self.__dict__.items():
+            setattr(param, k, v)
+        return param
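materialize_nested packs the collected shards into a single nested uint8 tensor, which lets shards quantized with different GGML types keep their own per-row byte widths inside one parameter. A toy sketch of that property (not vLLM code; the widths are illustrative):

import torch

# Two shards with different row widths, as different quant types would produce;
# a nested tensor preserves each shard's own shape inside one packed object.
gate = torch.randint(0, 256, (8, 144), dtype=torch.uint8)
up = torch.randint(0, 256, (8, 160), dtype=torch.uint8)
packed = torch.nested.nested_tensor([gate, up], dtype=torch.uint8)
g, u = packed.unbind(0)
assert g.shape == (8, 144) and u.shape == (8, 160)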
@@ -512,7 +512,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 quant_config=quant_config,
             )
             if config.tie_word_embeddings:
-                self.lm_head.weight = self.model.embed_tokens.weight
+                self.lm_head = self.model.embed_tokens
 
             logit_scale = getattr(config, "logit_scale", 1.0)
             self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
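Tied embeddings are now shared at the module level rather than by assigning a .weight tensor, presumably so that a quantized (e.g. GGUF) embedding's storage and method are reused for the output head as well. A toy illustration of module-level tying (not vLLM code):

import torch.nn as nn

# Both attribute names resolve to the same nn.Module object, so whatever
# storage or quantization embed_tokens uses is automatically reused.
class TinyLM(nn.Module):
    def __init__(self, vocab: int = 16, dim: int = 8):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = self.embed_tokens  # module-level tie

model = TinyLM()
assert model.lm_head is model.embed_tokens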