[Core] Refactor GGUF parameters packing and forwarding (#8859)

This commit is contained in:
Isotr0py 2024-10-07 18:01:46 +08:00 committed by GitHub
parent 4f95ffee6f
commit f19da64871
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 62 deletions

View File

@ -19,12 +19,12 @@ MAX_MODEL_LEN = 1024
# FIXME: Move this to confest # FIXME: Move this to confest
MODELS = [ MODELS = [
("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ("meta-llama/Llama-3.2-1B-Instruct",
hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")), filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
("TinyLlama/TinyLlama-1.1B-Chat-v1.0", ("meta-llama/Llama-3.2-1B-Instruct",
hf_hub_download("duyntnet/TinyLlama-1.1B-Chat-v1.0-imatrix-GGUF", hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
filename="TinyLlama-1.1B-Chat-v1.0-IQ4_XS.gguf")), filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
("Qwen/Qwen2-1.5B-Instruct", ("Qwen/Qwen2-1.5B-Instruct",
hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF", hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
filename="qwen2-1_5b-instruct-q4_k_m.gguf")), filename="qwen2-1_5b-instruct-q4_k_m.gguf")),

View File

@ -440,17 +440,23 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.shard_weight_type[loaded_shard_id] = loaded_weight.item() param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return return
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight:
from gguf.constants import GGML_QUANT_SIZES tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
ori_shape = param.tensor_shape output_dim = getattr(param, "output_dim", None)
weight_types = self.qweight_type.shard_weight_type.values() shard_size = loaded_weight.size(output_dim) // tp_size
row_size = [] start_idx = tp_rank * shard_size
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type] loaded_weight = loaded_weight.narrow(output_dim, start_idx,
row_size.append(ori_shape[1] // block_size * type_size) shard_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
@ -515,18 +521,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
@ -783,17 +777,23 @@ class QKVParallelLinear(ColumnParallelLinear):
param.shard_weight_type[loaded_shard_id] = loaded_weight.item() param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return return
if is_gguf_weight and isinstance(param, UninitializedParameter): if is_gguf_weight:
from gguf.constants import GGML_QUANT_SIZES tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
ori_shape = param.tensor_shape output_dim = getattr(param, "output_dim", None)
weight_types = self.qweight_type.shard_weight_type.values() shard_size = loaded_weight.size(output_dim) // tp_size
row_size = [] start_idx = tp_rank * shard_size
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type] loaded_weight = loaded_weight.narrow(output_dim, start_idx,
row_size.append(ori_shape[1] // block_size * type_size) shard_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
@ -883,18 +883,6 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape
input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":

View File

@ -86,15 +86,16 @@ class GGUFLinearMethod(LinearMethodBase):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
tensor_shape = (output_size_per_partition, input_size_per_partition) tensor_shape = (output_size_per_partition, input_size_per_partition)
qweight = UninitializedParameter(requires_grad=False) qweight = GGUFUninitializedParameter(requires_grad=False)
set_weight_attrs( set_weight_attrs(
qweight, { qweight, {
"input_dim": 1, "input_dim": 1,
"output_dim": 0, "output_dim": 0,
"tensor_shape": tensor_shape, "tensor_shape": tensor_shape,
"is_gguf_weight": True, "is_gguf_weight": True,
"shard_size": {}, "data_container": [],
"shard_id": [], "shard_id": [],
"shard_id_map": {},
}) })
set_weight_attrs(qweight, extra_weight_attrs) set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qweight", qweight) layer.register_parameter("qweight", qweight)
@ -116,21 +117,17 @@ class GGUFLinearMethod(LinearMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
shard_size = getattr(layer.qweight, "shard_size", None)
shard_id = getattr(layer.qweight, "shard_id", None) shard_id = getattr(layer.qweight, "shard_id", None)
if shard_id and shard_size: if shard_id:
result = []
offset = 0
# dequantize shard weights respectively # dequantize shard weights respectively
shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
qweight = layer.qweight.unbind(0)
result = []
for id in shard_id: for id in shard_id:
shard_weight = layer.qweight[ q_idx = layer.qweight.shard_id_map[id]
offset:offset +
shard_size[id][0], :shard_size[id][1]].contiguous()
qweight_type = layer.qweight_type.shard_weight_type[id] qweight_type = layer.qweight_type.shard_weight_type[id]
result.append(_fuse_mul_mat(x, shard_weight, qweight_type)) result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type))
offset += shard_size[id][0]
out = torch.cat(result, axis=1) out = torch.cat(result, axis=1)
else: else:
qweight = layer.qweight qweight = layer.qweight
@ -162,3 +159,20 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0]) x_flat.shape[0])
return dequant.view(*x.shape, hidden_size) return dequant.view(*x.shape, hidden_size)
class GGUFUninitializedParameter(UninitializedParameter):
cls_to_become = Parameter
data_container: List[torch.Tensor]
def materialize_nested(self) -> Parameter:
nested_data = torch.nested.nested_tensor(self.data_container,
device=self.device,
dtype=torch.uint8)
self.data_container.clear()
param = torch.Tensor._make_subclass(self.cls_to_become,
nested_data,
require_grad=False)
for k, v in self.__dict__.items():
setattr(param, k, v)
return param

View File

@ -512,7 +512,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
quant_config=quant_config, quant_config=quant_config,
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head = self.model.embed_tokens
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,