[Misc]Reduce BNB static variable (#9987)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li authored on 2024-11-05 01:04:40 +08:00; committed by GitHub
parent 8d72bb20fa
commit fb2716d641
10 changed files with 20 additions and 46 deletions
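
In short: the per-model `column_parallel_weights_modules` class attribute is deleted from the model files below, and `BitsAndBytesModelLoader` instead derives two lists at load time by walking `model.named_modules()`: modules backed by `ReplicatedLinear` are recorded as unsharded, and modules backed by `RowParallelLinear` are recorded as sharded along the column dimension under tensor parallelism. A minimal, self-contained sketch of that classification step follows; the `ReplicatedLinear`/`RowParallelLinear` classes here are plain `torch.nn` stand-ins and `ToyBlock` is an invented model, not vLLM code.

```python
from typing import List

import torch.nn as nn


# Stand-ins for vLLM's ReplicatedLinear / RowParallelLinear layer classes;
# only the types matter for this sketch, not the parallel logic.
class ReplicatedLinear(nn.Linear):
    pass


class RowParallelLinear(nn.Linear):
    pass


class ToyBlock(nn.Module):
    """A made-up module tree standing in for a loaded model."""

    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(8, 16)       # left alone by this scan
        self.down_proj = RowParallelLinear(16, 8)  # sharded along dim=-1 in TP
        self.score = ReplicatedLinear(8, 1)        # never sharded


def classify_bnb_modules(model: nn.Module):
    """Derive both module-name lists from module types, as the loader now does."""
    unsharded_weights_modules: List[str] = []
    column_sharded_weights_modules: List[str] = []
    for name, module in model.named_modules():
        if isinstance(module, (ReplicatedLinear, )):
            unsharded_weights_modules.append(name)
        elif isinstance(module, (RowParallelLinear, )):
            column_sharded_weights_modules.append(name)
    return unsharded_weights_modules, column_sharded_weights_modules


if __name__ == "__main__":
    unsharded, column_sharded = classify_bnb_modules(ToyBlock())
    print(unsharded)        # ['score']
    print(column_sharded)   # ['down_proj']
```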

View File

@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -727,6 +728,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
+        # Save the module names without sharding.
+        self.unsharded_weights_modules: List[str] = []
+        # Save the module names that are sharded by column.
+        self.column_sharded_weights_modules: List[str] = []
 
         # we don't need to quantize the whole model, only the target modules
         # that are specified in the adapter config file. If the adapter config
         # file is not provided, we will quantize the default modules.
@@ -744,8 +749,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             with open(config_file_path, "r") as f:
                 config = json.load(f)
                 self.target_modules = config["target_modules"]
-        # Save the module names without sharding.
-        self.unsharded_weights_modules: List[str] = []
 
     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -971,9 +974,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                         for module in self.unsharded_weights_modules):
                     weight_sub_tensor = weight_tensor
                 # Shard by column
-                elif any(module in weight_name
-                         for module in self.column_parallel_weights_modules):
+                elif any(
+                        weight_name.startswith(module)
+                        for module in self.column_sharded_weights_modules):
                     total_size = weight_tensor.size(-1)
                     start_index = total_size // tp_size * tp_rank
                     end_index = total_size // tp_size * (tp_rank + 1)
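
For weights whose module names land in `column_sharded_weights_modules`, the hunk above takes an equal slice of the last dimension per tensor-parallel rank. A small sketch of just that index arithmetic; the helper name `shard_by_column` is invented for illustration and is not a vLLM API:

```python
import torch


def shard_by_column(weight: torch.Tensor, tp_rank: int,
                    tp_size: int) -> torch.Tensor:
    # Same arithmetic as the hunk above: an equal slice of the last
    # dimension (dim=-1) for each tensor-parallel rank.
    total_size = weight.size(-1)
    start_index = total_size // tp_size * tp_rank
    end_index = total_size // tp_size * (tp_rank + 1)
    return weight[..., start_index:end_index]


if __name__ == "__main__":
    w = torch.arange(2 * 8).reshape(2, 8)
    # With tp_size=4, rank 1 sees columns 2 and 3 of each row.
    print(shard_by_column(w, tp_rank=1, tp_size=4))
```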
@@ -1028,20 +1031,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             else:
                 self.target_modules = self.default_target_modules
 
-            if hasattr(model, 'column_parallel_weights_modules'):
-                self.column_parallel_weights_modules = \
-                    model.column_parallel_weights_modules
-            else:
-                self.column_parallel_weights_modules = []
+            for name, module in model.named_modules():
             # Some modules like `ReplicatedLinear` should not have their weights
             # sharded. The reason for implementing it this way is to avoid new
             # static variable in the model implementation.
-            # TODO: Can we reduce the static variables needed for BNB based on
-            # model information?
-            self.unsharded_weights_modules = [
-                name for name, module in model.named_modules()
-                if isinstance(module, (ReplicatedLinear, ))
-            ]
+                if isinstance(module, (ReplicatedLinear, )):
+                    self.unsharded_weights_modules.append(name)
+                # In TP, these weights are partitioned along the column
+                # dimension (dim=-1)
+                elif isinstance(module, (RowParallelLinear, )):
+                    self.column_sharded_weights_modules.append(name)
+
             self.model_type = type(model).__name__
 
             logger.info("Loading weights with BitsAndBytes quantization. "

View File

@@ -401,8 +401,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
         ".dense_h_to_4h.",
         ".dense_4h_to_h.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]
 
     def __init__(
         self,

View File

@@ -350,7 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "gate_up_proj",
         "down_proj",
     ]
-
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
         ".gate_proj.",
@@ -361,8 +360,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),

View File

@@ -390,8 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),

View File

@@ -464,8 +464,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),

View File

@@ -854,10 +854,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         # resampler
         ".kv_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [
-        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1008,10 +1004,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         # resampler
         ".kv_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [
-        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),

View File

@@ -1062,8 +1062,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         # so we can't add a dot in front of it.
         "multi_modal_projector."
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),

View File

@@ -343,8 +343,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
     default_bitsandbytes_target_modules = [
         ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".out_proj.", ".fc2."]
 
     def __init__(
         self,

View File

@@ -274,8 +274,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     default_bitsandbytes_target_modules = [
         ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".fc2.", ".dense."]
 
     embedding_modules = {}
     embedding_padding_modules = []

View File

@@ -395,9 +395,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
-
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),