From 8a06428c70657b3310a317b3caf3c562b0e042ae Mon Sep 17 00:00:00 2001 From: Umesh Date: Tue, 12 Nov 2024 11:08:40 -0800 Subject: [PATCH] [LoRA] Adds support for bias in LoRA (#5733) Signed-off-by: Umesh Deshpande Co-authored-by: Umesh Deshpande --- tests/lora/conftest.py | 5 + tests/lora/test_lora_bias_e2e.py | 52 ++++++ tests/lora/test_utils.py | 14 +- vllm/config.py | 1 + vllm/engine/arg_utils.py | 5 + vllm/lora/fully_sharded_layers.py | 33 ++++ vllm/lora/layers.py | 296 +++++++++++++++++++++++++++++- vllm/lora/lora.py | 17 +- vllm/lora/models.py | 36 +++- vllm/lora/utils.py | 17 +- 10 files changed, 456 insertions(+), 20 deletions(-) create mode 100644 tests/lora/test_lora_bias_e2e.py diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 816d3986..29ecf378 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id): return snapshot_download(repo_id=sql_lora_huggingface_id) +@pytest.fixture(scope="session") +def lora_bias_files(): + return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias") + + @pytest.fixture(scope="session") def mixtral_lora_files(): # Note: this module has incorrect adapter_config.json to test diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py new file mode 100644 index 00000000..c2520c84 --- /dev/null +++ b/tests/lora/test_lora_bias_e2e.py @@ -0,0 +1,52 @@ +from typing import List + +import pytest + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "ibm-granite/granite-3b-code-base" + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. 
[/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + generated_texts: List[str] = [] + for output in outputs: + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + return generated_texts + + +@pytest.mark.parametrize("lora_bias", [True]) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_lora_rank=8, + max_loras=1, + enable_lora_bias=lora_bias, + tensor_parallel_size=1, + fully_sharded_loras=fully_sharded) + + print("lora adapter created") + output1 = do_sample(llm, lora_bias_files, lora_id=0) + + print("lora") + output2 = do_sample(llm, lora_bias_files, lora_id=1) + + if lora_bias: + assert output1 != output2 + else: + assert output1 == output2 diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index db02bacd..85110b8f 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -12,36 +12,40 @@ from vllm.utils import LRUCache def test_parse_fine_tuned_lora_name_valid(): fixture = { - ("base_model.model.lm_head.lora_A.weight", "lm_head", True), - ("base_model.model.lm_head.lora_B.weight", "lm_head", False), + ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False), + ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False), ( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, + False, ), ( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, + False, ), ( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, + False, ), ( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, + False, ), } - for name, module_name, is_lora_a in fixture: - assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) + for name, module_name, is_lora_a, is_bias in fixture: + assert (module_name, is_lora_a, + is_bias) == parse_fine_tuned_lora_name(name) def test_parse_fine_tuned_lora_name_invalid(): fixture = { - "weight", "base_model.weight", "base_model.model.weight", } diff --git a/vllm/config.py b/vllm/config.py index b354fb61..5ba1c41f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1687,6 +1687,7 @@ class LoRAConfig: # This is a constant. 
lora_vocab_padding_size: ClassVar[int] = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None + bias_enabled: bool = False def __post_init__(self): # Setting the maximum rank to 256 should be able to satisfy the vast diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 1591059a..27f62b00 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -143,6 +143,7 @@ class EngineArgs: limit_mm_per_prompt: Optional[Mapping[str, int]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None enable_lora: bool = False + enable_lora_bias: bool = False max_loras: int = 1 max_lora_rank: int = 16 enable_prompt_adapter: bool = False @@ -584,6 +585,9 @@ class EngineArgs: parser.add_argument('--enable-lora', action='store_true', help='If True, enable handling of LoRA adapters.') + parser.add_argument('--enable-lora-bias', + action='store_true', + help='If True, enable bias for LoRA adapters.') parser.add_argument('--max-loras', type=int, default=EngineArgs.max_loras, @@ -1148,6 +1152,7 @@ class EngineArgs: and parallel_config.use_ray), policy=self.scheduling_policy) lora_config = LoRAConfig( + bias_enabled=self.enable_lora_bias, max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, fully_sharded_loras=self.fully_sharded_loras, diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index a7887a04..04fc6358 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): self.lora_b_stacked, add_input=True) # now have column partitioned output + + if self.bias_stacked is not None: + self.bias_stacked = self.bias_stacked.view( + -1, self.bias_stacked.shape[-1]) + self.bias_stacked = self.bias_stacked[ + self.punica_wrapper.token_lora_indices] + output += self.bias_stacked + output = output.view(*out_orig_shape) return output @@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): left_offset = 0 for idx in range(n): shard_size = layer.lora_b_stacked[idx].shape[2] + + if layer.bias_stacked is not None: + bias = layer.bias_stacked[idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[layer.punica_wrapper.token_lora_indices] + bias[layer.punica_wrapper.token_lora_indices == -1] = 0 + output[:, left_offset:left_offset + shard_size] += bias + layer.punica_wrapper.add_expand_slice( output, buffers[idx], @@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): lora_b = lora_b[:, start_idx:end_idx] return lora_b + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + shard_size = self.bias_stacked.shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) @@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): # reduced before being used shard_size = self.lora_b_stacked.shape[2] start_idx = self.tp_rank * shard_size + + if self.bias_stacked is not None: + bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1]) + bias = bias[self.punica_wrapper.token_lora_indices] + bias[self.punica_wrapper.token_lora_indices == -1] = 0 + output += bias + self.punica_wrapper.add_expand_slice(output, buffer, self.lora_b_stacked, start_idx, shard_size) diff --git a/vllm/lora/layers.py 
b/vllm/lora/layers.py index 6254c675..7429c60e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -67,6 +67,63 @@ def _not_fully_sharded_can_replace(can_replace): return dec +def apply_bias( + indices: torch.Tensor, + output: torch.Tensor, + bias_stacked: torch.Tensor, +): + """Applies bias to output + + Input shapes: + bias_stacked: (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) + bias_stacked = bias_stacked[indices] + bias_stacked[indices == -1] = 0 + output += bias_stacked + + return output.view_as(org_output) + + +def apply_bias_packed_nslice( + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], +): + """Applies bias to output + + Input shapes: + bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + + offset_left += slice + + return output.view_as(org_output) + + @dataclass class LoRAMapping(AdapterMapping): is_prefill: bool = False @@ -105,6 +162,7 @@ class BaseLayerWithLoRA(nn.Module): lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): """Overwrites lora tensors at index.""" ... 
@@ -203,6 +261,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( @@ -299,10 +358,22 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): dtype=lora_config.lora_dtype, device=self.device, ) + if lora_config.bias_enabled: + self.bias_stacked = torch.zeros( + max_loras, + 1, + self.output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + else: + self.bias_stacked = None def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + if self.lora_config.bias_enabled: + self.bias_stacked[index] = 0 def set_lora( self, @@ -310,6 +381,7 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) @@ -319,10 +391,21 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) + if bias is not None: + self.bias_stacked[index, + 0, :bias.shape[0]].copy_(bias.T, + non_blocking=True) def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + if self.bias_stacked is not None: + self.indices = self.punica_wrapper.token_lora_indices + output = apply_bias( + self.indices, + output, + self.bias_stacked, + ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0) return output @@ -401,11 +484,25 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): dtype=lora_config.lora_dtype, device=self.device, ) + + if lora_config.bias_enabled: + self.bias_stacked = torch.zeros( + max_loras, + 1, + self.output_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + else: + self.bias_stacked = None + self.output_dim = self.lora_b_stacked.shape[2] def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + if self.lora_config.bias_enabled: + self.bias_stacked[index] = 0 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: return lora_a @@ -418,18 +515,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): lora_b = lora_b[:, start_idx:end_idx] return lora_b + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + if bias is None: + return bias + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + bias = bias[start_idx:end_idx] + return bias + def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) + bias = self.slice_bias(bias) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -437,10 +546,21 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) + if bias is not None: + self.bias_stacked[index, + 0, :bias.shape[0]].copy_(bias.T, + non_blocking=True) def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> 
torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + if self.bias_stacked is not None: + self.indices = self.punica_wrapper.token_lora_indices + output = apply_bias( + self.indices, + output, + self.bias_stacked, + ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0) return output @@ -534,6 +654,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): dtype=lora_config.lora_dtype, device=self.device, ) for _ in range(n_slices)) + if lora_config.bias_enabled: + self.bias_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.output_size // 2, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(n_slices)) + else: + self.bias_stacked = None self.output_dim = self.lora_b_stacked[0].shape[2] @@ -542,6 +673,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.lora_a_stacked[1][index] = 0 self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 + if self.lora_config.bias_enabled: + self.bias_stacked[0][index] = 0 + self.bias_stacked[1][index] = 0 def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] @@ -562,18 +696,32 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): ] return lora_b + def slice_bias( + self, bias: List[Union[torch.Tensor, + None]]) -> List[Union[torch.Tensor, None]]: + if bias[0] is None or bias[1] is None: + return bias + shard_size = self.output_dim + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]] + return bias + def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) + if bias is not None: + bias = self.slice_bias(bias) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -582,6 +730,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.lora_b_stacked[0][ index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( lora_b[0].T, non_blocking=True) + if bias is not None and bias[0] is not None: + self.bias_stacked[0][index, + 0, :bias[0].shape[0]].copy_(bias[0].T, + non_blocking=True) if lora_a[1] is not None: self.lora_a_stacked[1][ index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( @@ -589,10 +741,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.lora_b_stacked[1][ index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( lora_b[1].T, non_blocking=True) + if bias is not None and bias[1] is not None: + self.bias_stacked[1][index, + 0, :bias[1].shape[0]].copy_(bias[1].T, + non_blocking=True) def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + if self.bias_stacked is not None: + self.indices = self.punica_wrapper.token_lora_indices + output = apply_bias_packed_nslice( + self.indices, + output, + (self.output_dim, self.output_dim), + self.bias_stacked, + ) self.punica_wrapper.add_lora_packed_nslice( output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, (self.output_dim, self.output_dim)) @@ -654,17 +818,35 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) return lora_b + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + bias_q = 
bias[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + bias_k = bias[k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + bias_v = bias[v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + bias = torch.cat([bias_q, bias_k, bias_v], dim=1) + return bias + def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) + if bias is not None: + bias = self.slice_bias(bias) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -672,6 +854,10 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) + if bias is not None: + self.bias_stacked[index, + 0, :bias.shape[0]].copy_(bias.T, + non_blocking=True) @classmethod @_not_fully_sharded_can_replace @@ -768,6 +954,32 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): device=self.device, ), ) + if lora_config.bias_enabled: + self.bias_stacked = ( + torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + dtype=lora_config.lora_dtype, + device=self.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + dtype=lora_config.lora_dtype, + device=self.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + dtype=lora_config.lora_dtype, + device=self.device, + ), + ) + else: + self.bias_stacked = None self.output_slices = ( self.q_proj_shard_size, @@ -787,6 +999,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): self.lora_b_stacked[1][index] = 0 self.lora_a_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0 + if self.lora_config.bias_enabled: + self.bias_stacked[0][index] = 0 + self.bias_stacked[1][index] = 0 + self.bias_stacked[2][index] = 0 def slice_lora_a( self, lora_a: List[Union[torch.Tensor, None]] @@ -812,18 +1028,40 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): lora_b = [lora_b_q, lora_b_k, lora_b_v] return lora_b + def slice_bias( + self, bias: List[Union[torch.Tensor, + None]]) -> List[Union[torch.Tensor, None]]: + bias_q, bias_k, bias_v = bias + if bias_q is not None: + bias_q = bias_q[self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + if bias_k is not None: + bias_k = bias_k[self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + if bias_v is not None: + bias_v = bias_v[self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + bias = [bias_q, bias_k, bias_v] + return bias + def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) + if bias is not None: + bias = self.slice_bias(bias) if lora_b[0] is not None: lora_b_q = lora_b[0] @@ -854,9 +1092,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( lora_a[2].T, 
non_blocking=True) + if bias is not None: + if bias[0] is not None: + self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_( + bias[0].T, non_blocking=True) + if bias[1] is not None: + self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_( + bias[1].T, non_blocking=True) + if bias[2] is not None: + self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_( + bias[2].T, non_blocking=True) + def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + if self.bias_stacked is not None: + self.indices = self.punica_wrapper.token_lora_indices + output = apply_bias_packed_nslice( + self.indices, + output, + self.output_slices, + self.bias_stacked, + ) self.punica_wrapper.add_lora_packed_nslice(output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, @@ -919,9 +1176,27 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): device=self.device, ) + if lora_config.bias_enabled: + self.bias_stacked = torch.zeros( + ( + max_loras, + 1, + self.output_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + else: + self.bias_stacked = None + # Lazily initialized + self.indices: torch.Tensor + self.indices_len: List[int] + def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + if self.lora_config.bias_enabled: + self.bias_stacked[index] = 0 def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tensor_model_parallel_rank = get_tensor_model_parallel_rank() @@ -934,18 +1209,24 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: return lora_b + def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + return bias + def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) if self.base_layer.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) + if bias is not None: + bias = self.slice_bias(bias) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -953,9 +1234,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): self.lora_b_stacked[index, 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( lora_b.T, non_blocking=True) + if bias is not None: + self.bias_stacked[index, + 0, :bias.shape[0]].copy_(bias.T, + non_blocking=True) def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) + if self.bias_stacked is not None: + self.indices = self.punica_wrapper.token_lora_indices + output = apply_bias( + self.indices, + output, + self.bias_stacked, + ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0) return output @@ -1132,6 +1424,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): self.reset_lora(index) self.lora_a_stacked[index, @@ -1199,7 +1492,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): neginf=float("-inf"))) logits[:, self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1], ] = lora_logits + lora_logits.shape[1]] = lora_logits # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, @@ -1276,6 +1569,7 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): lora_a: torch.Tensor, lora_b: torch.Tensor, 
embeddings_tensor: Optional[torch.Tensor], + bias: Optional[torch.Tensor] = None, ): ... diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 14081b5b..b648312b 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -17,6 +17,7 @@ class LoRALayerWeights: lora_alpha: int, lora_a: torch.Tensor, lora_b: torch.Tensor, + bias: Optional[torch.Tensor] = None, embeddings_tensor: Optional[torch.Tensor] = None, scaling: Optional[float] = None, ) -> None: @@ -25,6 +26,7 @@ class LoRALayerWeights: self.lora_alpha = lora_alpha self.lora_a = lora_a self.lora_b = lora_b + self.bias = bias self.embeddings_tensor = embeddings_tensor if scaling is None: @@ -66,7 +68,8 @@ class LoRALayerWeights: rank: int, dtype: torch.dtype, device: torch.types.Device, - embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + embeddings_tensor_dim: Optional[int] = None, + bias_enabled: Optional[bool] = False) -> "LoRALayerWeights": pin_memory = str(device) == "cpu" and is_pin_memory_available() lora_a = torch.zeros([input_dim, rank], dtype=dtype, @@ -76,6 +79,14 @@ class LoRALayerWeights: dtype=dtype, device=device, pin_memory=pin_memory) + if bias_enabled: + bias = torch.zeros([output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + else: + bias = None + embeddings_tensor = torch.rand( 10, embeddings_tensor_dim, @@ -88,6 +99,7 @@ class LoRALayerWeights: lora_alpha=1, lora_a=lora_a, lora_b=lora_b, + bias=bias, embeddings_tensor=embeddings_tensor, ) @@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): lora_alphas: List[Optional[int]], lora_a: List[Optional[torch.Tensor]], lora_b: List[Optional[torch.Tensor]], + bias: Optional[List[Optional[torch.Tensor]]] = None, scaling: Optional[List[float]] = None, ) -> None: super().__init__( @@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): lora_alpha=0, lora_a=lora_a, lora_b=lora_b, + bias=bias, scaling=scaling, # type: ignore embeddings_tensor=None, ) @@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights): [lora.lora_alpha if lora is not None else None for lora in loras], [lora.lora_a if lora is not None else None for lora in loras], [lora.lora_b if lora is not None else None for lora in loras], + [lora.bias if lora is not None else None for lora in loras], scaling=[ 1 if lora is not None else None # type: ignore for lora in loras diff --git a/vllm/lora/models.py b/vllm/lora/models.py index eafc3a43..2ffefe61 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,7 +4,7 @@ import math import os import re from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Sequence, Type import safetensors.torch import torch @@ -119,7 +119,8 @@ class LoRAModel(AdapterModel): pin_memory = str(device) == "cpu" and is_pin_memory_available() loras: Dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): - module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name( + tensor_name) if module_name not in loras: lora_embeddings_tensor = None if embeddings: @@ -136,8 +137,16 @@ class LoRAModel(AdapterModel): lora_embeddings_tensor.pin_memory()) loras[module_name] = LoRALayerWeights(module_name, rank, lora_alpha, None, None, + None, lora_embeddings_tensor) - if is_lora_a: + if is_bias: + loras[module_name].bias = tensor.to(device=device, + dtype=dtype).t() + bias = tensor.to(device=device, dtype=dtype).t() + if pin_memory: + 
bias = bias.pin_memory() + loras[module_name].bias = bias + elif is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype).t() if pin_memory: @@ -215,7 +224,7 @@ class LoRAModel(AdapterModel): with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore for lora_module in f.keys(): # noqa - module_name, _ = parse_fine_tuned_lora_name(lora_module) + module_name, _, _ = parse_fine_tuned_lora_name(lora_module) part_name = module_name.split(".")[-1] if part_name not in expected_lora_modules: unexpected_modules.append(module_name) @@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager): module_lora = lora_model.get_lora(module_name) if module_lora: module_lora.optimize() + # Bias is not explicitly enabled with the flag enable_lora_bias. + bias = module_lora.bias + if ((torch.is_tensor(bias) or + (isinstance(bias, Sequence) and any(b is not None + for b in bias))) + and not self.lora_config.bias_enabled): + module_lora.bias = None + raise ValueError( + f"Adapter bias cannot be used for {module_name}" + " without --enable-lora-bias.") module.set_lora(index, module_lora.lora_a, module_lora.lora_b, - module_lora.embeddings_tensor) + module_lora.embeddings_tensor, + module_lora.bias) else: module.reset_lora(index) return True @@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager): """Create zero-initialized LoRAModel for warmup.""" model = LoRAModel(lora_id, rank, {}, scaling_factor) for module_name, module in self.model.named_modules(): + bias_enabled = self.lora_config.bias_enabled if (not self._match_target_modules(module_name) or not isinstance(module, BaseLayerWithLoRA) or isinstance(module, LinearScalingRotaryEmbeddingWithLora) @@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager): rank, module.lora_a_stacked.dtype, "cpu", - embeddings_tensor_dim=embeddings_tensor_dim) + embeddings_tensor_dim=embeddings_tensor_dim, + bias_enabled=bias_enabled) else: lora = LoRALayerWeights.create_dummy_lora_weights( module_name, @@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager): rank, module.lora_a_stacked.dtype, "cpu", + bias_enabled=bias_enabled, ) lora.optimize() else: @@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager): rank, module.lora_a_stacked[i].dtype, "cpu", + bias_enabled=bias_enabled, ) lora.optimize() subloras.append(lora) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index a780429f..5876494c 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str, return new_module -def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]: """Parse the name of lora weights. args: @@ -101,15 +101,18 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: Tuple(module_name, is_lora_a): module_name: the name of the module, e.g. model.dense1, is_lora_a whether the tensor is lora_a or lora_b. + is_bias whether the tensor is lora bias. 
""" parts = name.split(".") + if parts[-1] == "weight" and (parts[-2] == "lora_A" + or parts[-2] == "lora_B"): + return ".".join(parts[2:-2]), parts[-2] == "lora_A", False - if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model": - if parts[-1] == "weight": - if parts[-2] == "lora_A" or parts[-2] == "lora_B": - return ".".join(parts[2:-2]), parts[-2] == "lora_A" - elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": - return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False + + if parts[-1] == "bias": + return ".".join(parts[2:-2]), False, True raise ValueError(f"{name} is unsupported LoRA weight")