[Misc]Reduce BNB static variable (#9987)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>

commit fb2716d641
parent 8d72bb20fa
@@ -28,7 +28,8 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (ReplicatedLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.model_loader.tensorizer import (
@@ -727,6 +728,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
 
+        # Save the module names without sharding.
+        self.unsharded_weights_modules: List[str] = []
+        # Save the module names that are sharded by column.
+        self.column_sharded_weights_modules: List[str] = []
         # we don't need to quantize the whole model, only the target modules
         # that are specified in the adapter config file. If the adapter config
         # file is not provided, we will quantize the default modules.
@@ -744,8 +749,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         with open(config_file_path, "r") as f:
             config = json.load(f)
             self.target_modules = config["target_modules"]
-        # Save the module names without sharding.
-        self.unsharded_weights_modules: List[str] = []
 
     def _get_config_file(self, qlora_adapter: str) -> str:
         is_local = os.path.isdir(qlora_adapter)
@@ -971,9 +974,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                    for module in self.unsharded_weights_modules):
                 weight_sub_tensor = weight_tensor
             # Shard by column
-            elif any(module in weight_name
-                     for module in self.column_parallel_weights_modules):
+            elif any(
+                    weight_name.startswith(module)
+                    for module in self.column_sharded_weights_modules):
                 total_size = weight_tensor.size(-1)
                 start_index = total_size // tp_size * tp_rank
                 end_index = total_size // tp_size * (tp_rank + 1)
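For context, the column-shard branch above splits a weight along its last dimension so that each tensor-parallel rank loads only its own slice. Below is a minimal, self-contained sketch of that index arithmetic in plain PyTorch; the tensor shape and TP size are made up for illustration, and this is not the loader's actual code.

import torch

def column_shard(weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Mirrors the start/end index arithmetic in the hunk above; assumes the
    # last dimension is evenly divisible by tp_size.
    total_size = weight.size(-1)
    start_index = total_size // tp_size * tp_rank
    end_index = total_size // tp_size * (tp_rank + 1)
    return weight[..., start_index:end_index]

# Example: a 4-way TP split of a (16, 8) weight gives each rank a (16, 2) slice,
# and concatenating the slices along dim=-1 reconstructs the full weight.
w = torch.randn(16, 8)
shards = [column_shard(w, r, 4) for r in range(4)]
assert all(s.shape == (16, 2) for s in shards)
assert torch.equal(torch.cat(shards, dim=-1), w)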
@@ -1028,20 +1031,17 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         else:
             self.target_modules = self.default_target_modules
 
-        if hasattr(model, 'column_parallel_weights_modules'):
-            self.column_parallel_weights_modules = \
-                model.column_parallel_weights_modules
-        else:
-            self.column_parallel_weights_modules = []
-        # Some modules like `ReplicatedLinear` should not have their weights
-        # sharded. The reason for implementing it this way is to avoid new
-        # static variable in the model implementation.
-        # TODO: Can we reduce the static variables needed for BNB based on
-        # model information?
-        self.unsharded_weights_modules = [
-            name for name, module in model.named_modules()
-            if isinstance(module, (ReplicatedLinear, ))
-        ]
+        for name, module in model.named_modules():
+            # Some modules like `ReplicatedLinear` should not have their weights
+            # sharded. The reason for implementing it this way is to avoid new
+            # static variable in the model implementation.
+            if isinstance(module, (ReplicatedLinear, )):
+                self.unsharded_weights_modules.append(name)
+            # In TP, these weights are partitioned along the column
+            # dimension (dim=-1)
+            elif isinstance(module, (RowParallelLinear, )):
+                self.column_sharded_weights_modules.append(name)
 
         self.model_type = type(model).__name__
 
         logger.info("Loading weights with BitsAndBytes quantization. "
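The hunk above is the core of the change: instead of reading a per-model `column_parallel_weights_modules` static list, the loader now walks `model.named_modules()` once and classifies each module by its layer type, keeping `ReplicatedLinear` modules unsharded and sharding `RowParallelLinear` modules along the column dimension. The following is a small sketch of that classification pattern; the `ReplicatedLinear` and `RowParallelLinear` classes here are stand-ins subclassing `nn.Linear`, not vLLM's real layers.

from typing import List, Tuple

import torch.nn as nn

class ReplicatedLinear(nn.Linear):  # stand-in for vLLM's ReplicatedLinear
    pass

class RowParallelLinear(nn.Linear):  # stand-in for vLLM's RowParallelLinear
    pass

def classify_modules(model: nn.Module) -> Tuple[List[str], List[str]]:
    # Same idea as the loop in the hunk above: derive the unsharded and
    # column-sharded module-name lists from layer types at load time.
    unsharded: List[str] = []
    column_sharded: List[str] = []
    for name, module in model.named_modules():
        if isinstance(module, (ReplicatedLinear, )):
            unsharded.append(name)
        elif isinstance(module, (RowParallelLinear, )):
            column_sharded.append(name)
    return unsharded, column_sharded

# Toy model with one module of each kind.
toy = nn.Module()
toy.kv_proj = ReplicatedLinear(8, 8)
toy.down_proj = RowParallelLinear(8, 8)
print(classify_modules(toy))  # (['kv_proj'], ['down_proj'])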
@@ -401,8 +401,6 @@ class FalconForCausalLM(nn.Module, SupportsPP):
         ".dense_h_to_4h.",
         ".dense_4h_to_h.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."]
 
     def __init__(
         self,
@@ -350,7 +350,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         "gate_up_proj",
         "down_proj",
     ]
-
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
         ".gate_proj.",
@@ -361,8 +360,6 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -390,8 +390,6 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -464,8 +464,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -854,10 +854,6 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
         # resampler
         ".kv_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [
-        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1008,10 +1004,6 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
         # resampler
         ".kv_proj.",
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [
-        ".down_proj.", ".o_proj.", ".self_attn.out_proj.", ".fc2."
-    ]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -1062,8 +1062,6 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         # so we can't add a dot in front of it.
         "multi_modal_projector."
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj.", ".fc2."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
@@ -343,8 +343,6 @@ class OPTForCausalLM(nn.Module, SupportsPP):
     default_bitsandbytes_target_modules = [
         ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2."
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".out_proj.", ".fc2."]
 
     def __init__(
         self,
@@ -274,8 +274,6 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     default_bitsandbytes_target_modules = [
         ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense."
     ]
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".fc2.", ".dense."]
 
     embedding_modules = {}
     embedding_padding_modules = []
@@ -395,9 +395,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         ".v_proj.",
         ".o_proj.",
     ]
-
-    # in TP, these weights are partitioned along the column dimension (dim=-1)
-    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
     bitsandbytes_stacked_params_mapping = {
         # shard_name, weight_name, index
         "q_proj": ("qkv_proj", 0),
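Taken together, the model-side hunks above leave each model class with only the BNB attributes that cannot be derived at load time, namely `default_bitsandbytes_target_modules` and `bitsandbytes_stacked_params_mapping`; the sharding lists are now inferred by the loader. A hypothetical model class illustrating the remaining surface (names and values echo the context lines above; the rest of the class body and mapping are elided):

import torch.nn as nn

class ExampleForCausalLM(nn.Module):  # hypothetical, for illustration only
    # BitandBytes specific attributes that still live on the model class.
    default_bitsandbytes_target_modules = [
        ".q_proj.", ".k_proj.", ".v_proj.", ".o_proj.",
        ".gate_proj.", ".down_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        # ... remaining entries elided ...
    }
    # No column_parallel_weights_modules attribute anymore: the loader derives
    # the unsharded / column-sharded module lists from the layer types.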