[Bugfix] Fix MiniCPMV and Mllama BNB bug (#9917)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 91c9ebbb1b
commit c49f0407ba
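Context: this change wires quantization support through the MiniCPM-V resampler, teaches the BitsAndBytes (BNB) loader about replicated (unsharded) layers, and extends the BNB target-module lists for MiniCPM-V 2.5/2.6 and for Mllama's vision components. A minimal usage sketch follows; the model ID and the flag combination (`quantization="bitsandbytes"` plus `load_format="bitsandbytes"`, as used for in-flight BNB loading in vLLM around this release) are illustrative rather than taken from this diff.

```python
# Hedged usage sketch (not part of the commit): load a MiniCPM-V checkpoint
# with in-flight BitsAndBytes quantization, the path this bugfix exercises.
from vllm import LLM, SamplingParams

llm = LLM(
    model="openbmb/MiniCPM-V-2_6",  # example model ID
    quantization="bitsandbytes",    # BNB in-flight quantization
    load_format="bitsandbytes",     # required alongside the flag above
    trust_remote_code=True,
)
outputs = llm.generate(["Describe a cat."], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```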
@@ -41,6 +41,7 @@ from torch import nn
 from torch.nn.init import trunc_normal_
 
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 
 DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
@@ -154,15 +155,15 @@ class BaseResampler(nn.Module):
         A tensor with the shape of (grid_size**2, embed_dim)
     """
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
 
         self.num_queries = num_queries
@@ -172,7 +173,11 @@ class BaseResampler(nn.Module):
         self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
         trunc_normal_(self.query, std=0.02)
         if kv_dim is not None and kv_dim != embed_dim:
-            self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False)
+            self.kv_proj = ReplicatedLinear(kv_dim,
+                                            embed_dim,
+                                            bias=False,
+                                            quant_config=quant_config,
+                                            prefix=prefix)
         else:
             # Maintain the same return value with ReplicatedLinear.forward
             self.kv_proj = lambda *args, **kwargs: (  # type: ignore # noqa
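The resampler hunks above add `quant_config` and `prefix` parameters and forward them into the `ReplicatedLinear` that implements `kv_proj`, so the projection is built quantization-aware instead of as a bare linear layer. A toy sketch of that constructor plumbing follows; `ToyQuantConfig`, `ToyReplicatedLinear`, and `ToyResampler` are made-up stand-ins, not vLLM classes.

```python
from typing import Optional

from torch import nn


class ToyQuantConfig:
    """Stand-in for vLLM's QuantizationConfig in this sketch."""
    name = "bitsandbytes"


class ToyReplicatedLinear(nn.Linear):
    """Stand-in for ReplicatedLinear: it records the quantization config and
    its prefix so a loader could later recognize and treat it specially."""

    def __init__(self, input_size: int, output_size: int, *,
                 bias: bool = True,
                 quant_config: Optional[ToyQuantConfig] = None,
                 prefix: str = "") -> None:
        super().__init__(input_size, output_size, bias=bias)
        self.quant_config = quant_config
        self.prefix = prefix


class ToyResampler(nn.Module):
    """Mirrors the plumbing in the hunks above: quant_config/prefix are
    accepted by the resampler and forwarded to its kv projection."""

    def __init__(self, kv_dim: int, embed_dim: int,
                 quant_config: Optional[ToyQuantConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.kv_proj = ToyReplicatedLinear(kv_dim, embed_dim, bias=False,
                                           quant_config=quant_config,
                                           prefix=prefix)


resampler = ToyResampler(1152, 4096, ToyQuantConfig(), prefix="resampler")
print(type(resampler.kv_proj).__name__, resampler.kv_proj.prefix)
```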
@@ -209,22 +214,24 @@ class Resampler2(BaseResampler):
     present in minicpmv2.0, but not qwen-vl.
     """
 
-    def __init__(
-        self,
-        grid_size: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        adaptive: bool = False,
-        do_post_projection: bool = True,
-    ) -> None:
+    def __init__(self,
+                 grid_size: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 adaptive: bool = False,
+                 do_post_projection: bool = True,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__(grid_size**2,
                          embed_dim,
                          num_heads,
                          kv_dim,
                          norm_layer,
-                         do_post_projection=do_post_projection)
+                         do_post_projection=do_post_projection,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.adaptive = adaptive
         pos_embed_arr = get_2d_sincos_pos_embed(embed_dim,
@@ -28,6 +28,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -771,6 +772,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         with open(config_file_path, "r") as f:
             config = json.load(f)
             self.target_modules = config["target_modules"]
+        # Save the module names without sharding.
+        self.unsharded_weights_modules: List[str] = []
 
     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -990,16 +993,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if any(target_module in weight_name for target_module in
                    self.target_modules) and weight_name.endswith(".weight"):
                 weight_name = weight_name.replace(".weight", ".qweight")
-
-                if any(module in weight_name
-                       for module in self.column_parallel_weights_modules):
+                # Without sharding
+                if any(
+                        weight_name.startswith(module)
+                        for module in self.unsharded_weights_modules):
+                    weight_sub_tensor = weight_tensor
+                # Shard by column
+                elif any(module in weight_name
+                         for module in self.column_parallel_weights_modules):
 
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
                     end_index = total_size // tp_size * (tp_rank + 1)
                     weight_sub_tensor = weight_tensor[...,
                                                       start_index:end_index]
-
+                # Shard by row
                 else:
                     total_size = weight_tensor.size(0)
                     start_index = total_size // tp_size * tp_rank
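The hunk above splits the pre-quantization sharding into three cases: weights whose names start with an entry of `unsharded_weights_modules` (replicated layers such as the resampler's `kv_proj`) are kept whole on every tensor-parallel rank, column-parallel weights are still split along the last dimension, and everything else is split along the first. A condensed, self-contained sketch of that decision follows; the helper name `shard_weight` and the example module lists are illustrative.

```python
from typing import List

import torch


def shard_weight(weight_name: str, weight: torch.Tensor,
                 unsharded_modules: List[str],
                 column_parallel_modules: List[str],
                 tp_size: int, tp_rank: int) -> torch.Tensor:
    """Illustrative version of the loader's sharding choice."""
    # Without sharding: replicated modules keep the full tensor on each rank.
    if any(weight_name.startswith(m) for m in unsharded_modules):
        return weight
    # Shard by column (last dim) for column-parallel modules.
    if any(m in weight_name for m in column_parallel_modules):
        chunk = weight.size(-1) // tp_size
        return weight[..., tp_rank * chunk:(tp_rank + 1) * chunk]
    # Shard by row (first dim) otherwise.
    chunk = weight.size(0) // tp_size
    return weight[tp_rank * chunk:(tp_rank + 1) * chunk, ...]


w = torch.randn(8, 6)
# A replicated resampler projection stays whole on every rank.
print(shard_weight("resampler.kv_proj.weight", w, ["resampler.kv_proj"],
                   [".o_proj."], tp_size=2, tp_rank=1).shape)  # (8, 6)
# A row-parallel weight is split along dim 0.
print(shard_weight("llm.layers.0.mlp.up_proj.weight", w, ["resampler.kv_proj"],
                   [".o_proj."], tp_size=2, tp_rank=1).shape)  # (4, 6)
```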
@@ -1053,7 +1061,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 model.column_parallel_weights_modules
         else:
             self.column_parallel_weights_modules = []
-
+        # Some modules like `ReplicatedLinear` should not have their weights
+        # sharded. The reason for implementing it this way is to avoid new
+        # static variable in the model implementation.
+        # TODO: Can we reduce the static variables needed for BNB based on
+        # model information?
+        self.unsharded_weights_modules = [
+            name for name, module in model.named_modules()
+            if isinstance(module, (ReplicatedLinear, ))
+        ]
         self.model_type = type(model).__name__
 
         logger.info("Loading weights with BitsAndBytes quantization. "
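Rather than adding yet another static per-model list, the hunk above derives `unsharded_weights_modules` at load time by walking the model with `named_modules()` and keeping the dotted names of `ReplicatedLinear` instances. A small stand-alone sketch of the same idea; `ReplicatedLinearStub` and `TinyModel` are made up for illustration.

```python
from torch import nn


class ReplicatedLinearStub(nn.Linear):
    """Stand-in for vLLM's ReplicatedLinear."""


class TinyModel(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.resampler = nn.Module()
        self.resampler.kv_proj = ReplicatedLinearStub(16, 32, bias=False)
        self.lm_head = nn.Linear(32, 100)


model = TinyModel()
# Same idea as the loader: collect the dotted names of replicated modules so
# their weights are never sharded across tensor-parallel ranks.
unsharded_weights_modules = [
    name for name, module in model.named_modules()
    if isinstance(module, (ReplicatedLinearStub, ))
]
print(unsharded_weights_modules)  # ['resampler.kv_proj']
```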
@@ -1100,7 +1116,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             for shard_name, (
                     weight_name, index
             ) in model.bitsandbytes_stacked_params_mapping.items():
-                if shard_name in quant_param_name:
+                shard_pos = quant_param_name.find(shard_name)
+                # Some models, such as MiniCPM V2.5/2.6, contain both
+                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
+                # from being incorrectly identified as being present in
+                # 'vpm.encoder.layers.0.self_attn.qkv_proj.qweight
+                if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".":
                     shard_index = index
                     quant_param_name = quant_param_name.replace(
                         shard_name, weight_name)
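This hunk carries the core MiniCPM-V fix. The old test was a plain substring check (`if shard_name in quant_param_name`), which misfires on models whose parameter names contain both `kv_proj` (resampler) and `qkv_proj` (vision encoder), as the in-code comment notes for MiniCPM V2.5/2.6: a stacked-shard key can be found inside an unrelated module name. The new condition only accepts a match that begins right after a dot, i.e. on a module-name boundary. A tiny reproduction of the predicate; the function name and example weight names are illustrative.

```python
def is_shard_at_module_boundary(param_name: str, shard_name: str) -> bool:
    """A shard key only counts when the character before the match is '.'."""
    pos = param_name.find(shard_name)
    return pos > 0 and param_name[pos - 1] == "."


# A genuine attention shard matches ...
print(is_shard_at_module_boundary(
    "llm.model.layers.0.self_attn.v_proj.qweight", "v_proj"))       # True
# ... but 'v_proj' hidden inside the resampler's 'kv_proj' does not,
print(is_shard_at_module_boundary(
    "resampler.kv_proj.qweight", "v_proj"))                         # False
# and 'kv_proj' is not mistaken for part of the vision tower's 'qkv_proj'.
print(is_shard_at_module_boundary(
    "vpm.encoder.layers.0.self_attn.qkv_proj.qweight", "kv_proj"))  # False
```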
@@ -131,16 +131,22 @@ DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6)
 
 class Resampler2_5(BaseResampler):
 
-    def __init__(
-        self,
-        num_queries: int,
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: Optional[int] = None,
-        norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
-        max_size: Tuple[int, int] = (70, 70),
-    ) -> None:
-        super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer)
+    def __init__(self,
+                 num_queries: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 kv_dim: Optional[int] = None,
+                 norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN,
+                 max_size: Tuple[int, int] = (70, 70),
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        super().__init__(num_queries,
+                         embed_dim,
+                         num_heads,
+                         kv_dim,
+                         norm_layer,
+                         quant_config=quant_config,
+                         prefix=prefix)
 
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)
@@ -404,7 +410,10 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else
                            self.vpm.embeddings.embed_dim)
         self.embed_dim = self.config.hidden_size
-        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
+        self.resampler = self.init_resampler(self.embed_dim,
+                                             self.vision_dim,
+                                             quant_config=quant_config,
+                                             prefix="resampler")
         self.resampler.to(device="cuda", dtype=param_dtype)
         # TODO: why is there _KEYS_TO_MODIFY_MAPPING? lm_head should be in llm
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -666,7 +675,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> nn.Module:
         raise NotImplementedError
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         raise NotImplementedError
 
     def get_vision_embedding(
@@ -743,16 +756,21 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_tokens(input_ids)
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2(
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                grid_size=int(math.sqrt(self.config.query_num)),
-                kv_dim=vision_dim,
-                adaptive=False,
-                do_post_projection=True,
-            )
+            resampler = Resampler2(embed_dim=embed_dim,
+                                   num_heads=embed_dim // 128,
+                                   grid_size=int(
+                                       math.sqrt(self.config.query_num)),
+                                   kv_dim=vision_dim,
+                                   adaptive=False,
+                                   do_post_projection=True,
+                                   quant_config=quant_config,
+                                   prefix=prefix)
 
         return resampler
 
@@ -825,9 +843,21 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        # Currently, vllm does not support BNB quantization for the `out_proj`
+        # of the resampler, so it's necessary to distinguish between the
+        # vision encoder and the resampler's out_proj. The same applies to
+        # MiniCPMV2_6.
+        ".self_attn.out_proj.",  # vision encoder out_proj
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
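The BNB loader matches this list against weight names by plain substring (see the `target_modules` check in the loader hunk above), which is why the entries carry surrounding dots and why the vision encoder's projection is spelled `.self_attn.out_proj.`: per the in-code comment, the resampler's `out_proj` is not BNB-quantized, so the pattern must not catch it. A small illustration of that filter; the weight names are made up.

```python
target_modules = [
    ".q_proj.", ".k_proj.", ".v_proj.", ".o_proj.",
    ".fc1.", ".fc2.",
    ".self_attn.out_proj.",  # vision encoder only
    ".kv_proj.",             # resampler projection
]

weights = [
    "vpm.encoder.layers.0.self_attn.out_proj.weight",  # quantized
    "resampler.attn.out_proj.weight",                  # skipped (resampler)
    "resampler.kv_proj.weight",                        # quantized
]

for name in weights:
    hit = any(t in name for t in target_modules) and name.endswith(".weight")
    print(name, "->", "quantize" if hit else "skip")
```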
@@ -877,14 +907,18 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -967,9 +1001,17 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        # vision encoder
+        ".fc1.",
+        ".fc2.",
+        ".self_attn.out_proj.",
+        # resampler
+        ".kv_proj.",
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [
+        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
+    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1019,15 +1061,19 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
             model.encoder.layers = model.encoder.layers[:-1]
         return model
 
-    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+    def init_resampler(self,
+                       embed_dim: int,
+                       vision_dim: int,
+                       quant_config: Optional[QuantizationConfig] = None,
+                       prefix: str = "") -> nn.Module:
         with set_default_torch_dtype(torch.float16):
             # The resampler in 2.6 remains consistent with the one in 2.5.
-            resampler = Resampler2_5(
-                num_queries=self.config.query_num,
-                embed_dim=embed_dim,
-                num_heads=embed_dim // 128,
-                kv_dim=vision_dim,
-            )
+            resampler = Resampler2_5(num_queries=self.config.query_num,
+                                     embed_dim=embed_dim,
+                                     num_heads=embed_dim // 128,
+                                     kv_dim=vision_dim,
+                                     quant_config=quant_config,
+                                     prefix=prefix)
         return resampler
 
     def get_vision_embedding(
@@ -1056,9 +1056,14 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         ".k_proj.",
         ".v_proj.",
         ".o_proj.",
+        ".fc1.",
+        ".fc2.",
+        # The `multi_modal_projector` is at the top level of the model,
+        # so we can't add a dot in front of it.
+        "multi_modal_projector."
     ]
     # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
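In the Mllama hunk above, `multi_modal_projector` is added without a leading dot because, as the in-code comment says, the module sits at the top level of the model: its weight names begin with the module name itself rather than with a dot, so a dotted pattern would never match. A quick check of that naming detail; the weight names are illustrative.

```python
targets = [".fc1.", ".fc2.", "multi_modal_projector."]

names = [
    "multi_modal_projector.weight",  # top-level module, no preceding dot
    "vision_model.transformer.layers.0.mlp.fc1.weight",
]
for n in names:
    # ".multi_modal_projector." would fail to match the first name.
    print(n, "->", any(t in n for t in targets))
```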