[LoRA] Adds support for bias in LoRA (#5733)

Signed-off-by: Umesh Deshpande <udeshpa@us.ibm.com>
Co-authored-by: Umesh Deshpande <udeshpa@us.ibm.com>
Umesh 2024-11-12 11:08:40 -08:00 committed by GitHub
parent b41fb9d3b1
commit 8a06428c70
10 changed files with 456 additions and 20 deletions
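
In short: LoRA adapters that ship bias tensors for their target modules can now add those biases at inference time, gated behind a new opt-in flag. A minimal offline-inference sketch mirroring the new end-to-end test below (the model and adapter are the ones the test uses; the prompt is illustrative):

from huggingface_hub import snapshot_download
import vllm
from vllm.lora.request import LoRARequest

# Bias application is opt-in: enable_lora_bias=True here, --enable-lora-bias on the CLI.
llm = vllm.LLM("ibm-granite/granite-3b-code-base",
               enable_lora=True,
               enable_lora_bias=True,
               max_lora_rank=8,
               max_loras=1)

lora_path = snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
outputs = llm.generate(
    ["[user] Write a SQL query ... [/user] [assistant]"],
    vllm.SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql-bias", 1, lora_path))
print(outputs[0].outputs[0].text)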

View File: tests/lora/conftest.py

@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id):
return snapshot_download(repo_id=sql_lora_huggingface_id)
@pytest.fixture(scope="session")
def lora_bias_files():
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
@pytest.fixture(scope="session")
def mixtral_lora_files():
# Note: this module has incorrect adapter_config.json to test

View File: tests/lora/test_lora_bias_e2e.py

@ -0,0 +1,52 @@
from typing import List
import pytest
import vllm
from vllm.lora.request import LoRARequest
MODEL_PATH = "ibm-granite/granite-3b-code-base"
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501
]
sampling_params = vllm.SamplingParams(temperature=0,
max_tokens=256,
stop=["[/assistant]"])
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
generated_texts: List[str] = []
for output in outputs:
generated_text = output.outputs[0].text
generated_texts.append(generated_text)
return generated_texts
@pytest.mark.parametrize("lora_bias", [True])
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_lora_rank=8,
max_loras=1,
enable_lora_bias=lora_bias,
tensor_parallel_size=1,
fully_sharded_loras=fully_sharded)
print("lora adapter created")
output1 = do_sample(llm, lora_bias_files, lora_id=0)
print("lora")
output2 = do_sample(llm, lora_bias_files, lora_id=1)
if lora_bias:
assert output1 != output2
else:
assert output1 == output2

View File: tests/lora/test_utils.py

@ -12,36 +12,40 @@ from vllm.utils import LRUCache
def test_parse_fine_tuned_lora_name_valid():
fixture = {
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
(
"base_model.model.model.embed_tokens.lora_embedding_A",
"model.embed_tokens",
True,
False,
),
(
"base_model.model.model.embed_tokens.lora_embedding_B",
"model.embed_tokens",
False,
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
"model.layers.9.mlp.down_proj",
True,
False,
),
(
"base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
"model.layers.9.mlp.down_proj",
False,
False,
),
}
for name, module_name, is_lora_a in fixture:
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
for name, module_name, is_lora_a, is_bias in fixture:
assert (module_name, is_lora_a,
is_bias) == parse_fine_tuned_lora_name(name)
def test_parse_fine_tuned_lora_name_invalid():
fixture = {
"weight",
"base_model.weight",
"base_model.model.weight",
}

View File: vllm/config.py

@ -1687,6 +1687,7 @@ class LoRAConfig:
# This is a constant.
lora_vocab_padding_size: ClassVar[int] = 256
long_lora_scaling_factors: Optional[Tuple[float]] = None
bias_enabled: bool = False
def __post_init__(self):
# Setting the maximum rank to 256 should be able to satisfy the vast

View File: vllm/engine/arg_utils.py

@ -143,6 +143,7 @@ class EngineArgs:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
mm_processor_kwargs: Optional[Dict[str, Any]] = None
enable_lora: bool = False
enable_lora_bias: bool = False
max_loras: int = 1
max_lora_rank: int = 16
enable_prompt_adapter: bool = False
@ -584,6 +585,9 @@ class EngineArgs:
parser.add_argument('--enable-lora',
action='store_true',
help='If True, enable handling of LoRA adapters.')
parser.add_argument('--enable-lora-bias',
action='store_true',
help='If True, enable bias for LoRA adapters.')
parser.add_argument('--max-loras',
type=int,
default=EngineArgs.max_loras,
@ -1148,6 +1152,7 @@ class EngineArgs:
and parallel_config.use_ray),
policy=self.scheduling_policy)
lora_config = LoRAConfig(
bias_enabled=self.enable_lora_bias,
max_lora_rank=self.max_lora_rank,
max_loras=self.max_loras,
fully_sharded_loras=self.fully_sharded_loras,
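
A small sketch of where the new argument lands (values are illustrative): --enable-lora-bias sets EngineArgs.enable_lora_bias, which is forwarded to LoRAConfig.bias_enabled when the engine config is built.

from vllm.config import LoRAConfig

# Equivalent of running with --enable-lora --enable-lora-bias --max-lora-rank 8.
lora_config = LoRAConfig(max_lora_rank=8, max_loras=1, bias_enabled=True)
assert lora_config.bias_enabled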

View File: vllm/lora/fully_sharded_layers.py

@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked,
add_input=True)
# now have column partitioned output
if self.bias_stacked is not None:
self.bias_stacked = self.bias_stacked.view(
-1, self.bias_stacked.shape[-1])
self.bias_stacked = self.bias_stacked[
self.punica_wrapper.token_lora_indices]
output += self.bias_stacked
output = output.view(*out_orig_shape)
return output
@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]
if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias
layer.punica_wrapper.add_expand_slice(
output,
buffers[idx],
@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
shard_size = self.bias_stacked.shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size
if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias
self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
shard_size)
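
The fully sharded variants slice the adapter bias the same way they slice lora_b: each tensor-parallel rank keeps only its contiguous chunk of the output dimension. A toy sketch of that indexing, with made-up sizes:

import torch

full_bias = torch.arange(8.0)      # hypothetical bias over the full output dim
tp_size = 2
shard_size = full_bias.shape[0] // tp_size

for tp_rank in range(tp_size):
    start_idx = tp_rank * shard_size          # same arithmetic as slice_bias
    end_idx = (tp_rank + 1) * shard_size
    print(tp_rank, full_bias[start_idx:end_idx])
# rank 0 -> tensor([0., 1., 2., 3.]); rank 1 -> tensor([4., 5., 6., 7.])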

View File: vllm/lora/layers.py

@ -67,6 +67,63 @@ def _not_fully_sharded_can_replace(can_replace):
return dec
def apply_bias(
indices: torch.Tensor,
output: torch.Tensor,
bias_stacked: torch.Tensor,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
bias_stacked = bias_stacked[indices]
bias_stacked[indices == -1] = 0
output += bias_stacked
return output.view_as(org_output)
def apply_bias_packed_nslice(
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
@ -105,6 +162,7 @@ class BaseLayerWithLoRA(nn.Module):
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
"""Overwrites lora tensors at index."""
...
@ -203,6 +261,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
@ -299,10 +358,22 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
if lora_config.bias_enabled:
self.bias_stacked = torch.zeros(
max_loras,
1,
self.output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
else:
self.bias_stacked = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0
def set_lora(
self,
@ -310,6 +381,7 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
@ -319,10 +391,21 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
@ -401,11 +484,25 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
)
if lora_config.bias_enabled:
self.bias_stacked = torch.zeros(
max_loras,
1,
self.output_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
else:
self.bias_stacked = None
self.output_dim = self.lora_b_stacked.shape[2]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
return lora_a
@ -418,18 +515,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
lora_b = lora_b[:, start_idx:end_idx]
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
if bias is None:
return bias
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
shard_size = self.output_dim
start_idx = tensor_model_parallel_rank * shard_size
end_idx = (tensor_model_parallel_rank + 1) * shard_size
bias = bias[start_idx:end_idx]
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
bias = self.slice_bias(bias)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@ -437,10 +546,21 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
@ -534,6 +654,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(n_slices))
if lora_config.bias_enabled:
self.bias_stacked = tuple(
torch.zeros(
max_loras,
1,
self.output_size // 2,
dtype=lora_config.lora_dtype,
device=self.device,
) for _ in range(n_slices))
else:
self.bias_stacked = None
self.output_dim = self.lora_b_stacked[0].shape[2]
@ -542,6 +673,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_a_stacked[1][index] = 0
self.lora_b_stacked[0][index] = 0
self.lora_b_stacked[1][index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[0][index] = 0
self.bias_stacked[1][index] = 0
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
@ -562,18 +696,32 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
]
return lora_b
def slice_bias(
self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]:
if bias[0] is None or bias[1] is None:
return bias
shard_size = self.output_dim
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
if lora_a[0] is not None:
self.lora_a_stacked[0][
@ -582,6 +730,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[0][
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
lora_b[0].T, non_blocking=True)
if bias is not None and bias[0] is not None:
self.bias_stacked[0][index,
0, :bias[0].shape[0]].copy_(bias[0].T,
non_blocking=True)
if lora_a[1] is not None:
self.lora_a_stacked[1][
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
@ -589,10 +741,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[1][
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
lora_b[1].T, non_blocking=True)
if bias is not None and bias[1] is not None:
self.bias_stacked[1][index,
0, :bias[1].shape[0]].copy_(bias[1].T,
non_blocking=True)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
(self.output_dim, self.output_dim),
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
(self.output_dim, self.output_dim))
@ -654,17 +818,35 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
bias_q = bias[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
k_offset = self.q_proj_total_size
bias_k = bias[k_offset +
self.kv_proj_shard_size * self.kv_shard_id:k_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
v_offset = k_offset + self.kv_proj_total_size
bias_v = bias[v_offset +
self.kv_proj_shard_size * self.kv_shard_id:v_offset +
self.kv_proj_shard_size * (self.kv_shard_id + 1)]
bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@ -672,6 +854,10 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
@classmethod
@_not_fully_sharded_can_replace
@ -768,6 +954,32 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
device=self.device,
),
)
if lora_config.bias_enabled:
self.bias_stacked = (
torch.zeros(
max_loras,
1,
self.q_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
torch.zeros(
max_loras,
1,
self.kv_proj_shard_size,
dtype=lora_config.lora_dtype,
device=self.device,
),
)
else:
self.bias_stacked = None
self.output_slices = (
self.q_proj_shard_size,
@ -787,6 +999,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self.lora_b_stacked[1][index] = 0
self.lora_a_stacked[2][index] = 0
self.lora_b_stacked[2][index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[0][index] = 0
self.bias_stacked[1][index] = 0
self.bias_stacked[2][index] = 0
def slice_lora_a(
self, lora_a: List[Union[torch.Tensor, None]]
@ -812,18 +1028,40 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
lora_b = [lora_b_q, lora_b_k, lora_b_v]
return lora_b
def slice_bias(
self, bias: List[Union[torch.Tensor,
None]]) -> List[Union[torch.Tensor, None]]:
bias_q, bias_k, bias_v = bias
if bias_q is not None:
bias_q = bias_q[self.q_proj_shard_size *
self.q_shard_id:self.q_proj_shard_size *
(self.q_shard_id + 1)]
if bias_k is not None:
bias_k = bias_k[self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
if bias_v is not None:
bias_v = bias_v[self.kv_proj_shard_size *
self.kv_shard_id:self.kv_proj_shard_size *
(self.kv_shard_id + 1)]
bias = [bias_q, bias_k, bias_v]
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
if lora_b[0] is not None:
lora_b_q = lora_b[0]
@ -854,9 +1092,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
lora_a[2].T, non_blocking=True)
if bias is not None:
if bias[0] is not None:
self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_(
bias[0].T, non_blocking=True)
if bias[1] is not None:
self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_(
bias[1].T, non_blocking=True)
if bias[2] is not None:
self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_(
bias[2].T, non_blocking=True)
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
self.output_slices,
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked,
self.lora_b_stacked, 1.0,
@ -919,9 +1176,27 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
device=self.device,
)
if lora_config.bias_enabled:
self.bias_stacked = torch.zeros(
(
max_loras,
1,
self.output_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
else:
self.bias_stacked = None
# Lazily initialized
self.indices: torch.Tensor
self.indices_len: List[int]
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
if self.lora_config.bias_enabled:
self.bias_stacked[index] = 0
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
@ -934,18 +1209,24 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b
def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
return bias
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
if self.base_layer.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
if bias is not None:
bias = self.slice_bias(bias)
self.lora_a_stacked[index,
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@ -953,9 +1234,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.lora_b_stacked[index,
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
lora_b.T, non_blocking=True)
if bias is not None:
self.bias_stacked[index,
0, :bias.shape[0]].copy_(bias.T,
non_blocking=True)
def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
return output
@ -1132,6 +1424,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
self.reset_lora(index)
self.lora_a_stacked[index,
@ -1199,7 +1492,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
neginf=float("-inf")))
logits[:,
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
lora_logits.shape[1], ] = lora_logits
lora_logits.shape[1]] = lora_logits
# LogitsProcessorWithLoRA always using bgmv
self.punica_wrapper.add_lora_logits(logits, hidden_states,
@ -1276,6 +1569,7 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: Optional[torch.Tensor],
bias: Optional[torch.Tensor] = None,
):
...
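
Each LoRA-enabled layer stacks the biases of all resident adapters into a (max_loras, 1, output_dim) buffer, and apply_bias gathers one row per token via token_lora_indices, zeroing the rows of tokens that carry no LoRA (index -1). A self-contained toy sketch of that gather (shapes and values are made up):

import torch

max_loras, output_dim = 2, 4
bias_stacked = torch.zeros(max_loras, 1, output_dim)
bias_stacked[0, 0] = 0.1    # adapter resident in slot 0
bias_stacked[1, 0] = 0.5    # adapter resident in slot 1

# token 0 uses slot 1, token 1 has no LoRA, token 2 uses slot 0
indices = torch.tensor([1, -1, 0])
output = torch.zeros(3, output_dim)    # stand-in for the base layer's output

bias = bias_stacked.view(-1, output_dim)[indices]   # one bias row per token
bias[indices == -1] = 0                             # no-LoRA tokens get no bias
output += bias
print(output)   # rows: [0.5, ...], [0.0, ...], [0.1, ...]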

View File: vllm/lora/lora.py

@ -17,6 +17,7 @@ class LoRALayerWeights:
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
embeddings_tensor: Optional[torch.Tensor] = None,
scaling: Optional[float] = None,
) -> None:
@ -25,6 +26,7 @@ class LoRALayerWeights:
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.bias = bias
self.embeddings_tensor = embeddings_tensor
if scaling is None:
@ -66,7 +68,8 @@ class LoRALayerWeights:
rank: int,
dtype: torch.dtype,
device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
embeddings_tensor_dim: Optional[int] = None,
bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
dtype=dtype,
@ -76,6 +79,14 @@ class LoRALayerWeights:
dtype=dtype,
device=device,
pin_memory=pin_memory)
if bias_enabled:
bias = torch.zeros([output_dim],
dtype=dtype,
device=device,
pin_memory=pin_memory)
else:
bias = None
embeddings_tensor = torch.rand(
10,
embeddings_tensor_dim,
@ -88,6 +99,7 @@ class LoRALayerWeights:
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
bias=bias,
embeddings_tensor=embeddings_tensor,
)
@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alphas: List[Optional[int]],
lora_a: List[Optional[torch.Tensor]],
lora_b: List[Optional[torch.Tensor]],
bias: Optional[List[Optional[torch.Tensor]]] = None,
scaling: Optional[List[float]] = None,
) -> None:
super().__init__(
@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
bias=bias,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
[lora.bias if lora is not None else None for lora in loras],
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
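
LoRALayerWeights now carries the optional per-module bias alongside lora_a and lora_b; a minimal construction sketch with made-up shapes:

import torch
from vllm.lora.lora import LoRALayerWeights

rank, input_dim, output_dim = 8, 16, 32
weights = LoRALayerWeights(
    module_name="model.layers.0.mlp.down_proj",
    rank=rank,
    lora_alpha=16,
    lora_a=torch.zeros(input_dim, rank),
    lora_b=torch.zeros(rank, output_dim),
    bias=torch.zeros(output_dim),          # new optional field
)
assert weights.bias is not None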

View File: vllm/lora/models.py

@ -4,7 +4,7 @@ import math
import os
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
import safetensors.torch
import torch
@ -119,7 +119,8 @@ class LoRAModel(AdapterModel):
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: Dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
tensor_name)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
@ -136,8 +137,16 @@ class LoRAModel(AdapterModel):
lora_embeddings_tensor.pin_memory())
loras[module_name] = LoRALayerWeights(module_name, rank,
lora_alpha, None, None,
None,
lora_embeddings_tensor)
if is_lora_a:
if is_bias:
loras[module_name].bias = tensor.to(device=device,
dtype=dtype).t()
bias = tensor.to(device=device, dtype=dtype).t()
if pin_memory:
bias = bias.pin_memory()
loras[module_name].bias = bias
elif is_lora_a:
loras[module_name].lora_a = tensor.to(device=device,
dtype=dtype).t()
if pin_memory:
@ -215,7 +224,7 @@ class LoRAModel(AdapterModel):
with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module)
module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager):
module_lora = lora_model.get_lora(module_name)
if module_lora:
module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias.
bias = module_lora.bias
if ((torch.is_tensor(bias) or
(isinstance(bias, Sequence) and any(b is not None
for b in bias)))
and not self.lora_config.bias_enabled):
module_lora.bias = None
raise ValueError(
f"Adapter bias cannot be used for {module_name}"
" without --enable-lora-bias.")
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
module_lora.embeddings_tensor)
module_lora.embeddings_tensor,
module_lora.bias)
else:
module.reset_lora(index)
return True
@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager):
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {}, scaling_factor)
for module_name, module in self.model.named_modules():
bias_enabled = self.lora_config.bias_enabled
if (not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager):
rank,
module.lora_a_stacked.dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim)
embeddings_tensor_dim=embeddings_tensor_dim,
bias_enabled=bias_enabled)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager):
rank,
module.lora_a_stacked.dtype,
"cpu",
bias_enabled=bias_enabled,
)
lora.optimize()
else:
@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager):
rank,
module.lora_a_stacked[i].dtype,
"cpu",
bias_enabled=bias_enabled,
)
lora.optimize()
subloras.append(lora)
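
Adapters whose checkpoints contain bias tensors are rejected at activation time unless the engine was started with the flag. A sketch of the failure mode (the adapter path is a placeholder; depending on the entrypoint the error may surface wrapped):

import vllm
from vllm.lora.request import LoRARequest

# Hypothetical: engine started WITHOUT enable_lora_bias, adapter ships .bias tensors.
llm = vllm.LLM("ibm-granite/granite-3b-code-base", enable_lora=True, max_lora_rank=8)

try:
    llm.generate("hello",
                 lora_request=LoRARequest("bias-adapter", 1,
                                          "/path/to/lora-with-bias"))
except ValueError as err:
    # "Adapter bias cannot be used for <module> without --enable-lora-bias."
    print(err)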

View File: vllm/lora/utils.py

@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str,
return new_module
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
"""Parse the name of lora weights.
args:
@ -101,15 +101,18 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
Tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
is_bias whether the tensor is lora bias.
"""
parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A"
or parts[-2] == "lora_B"):
return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
if parts[-1] == "weight":
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
if parts[-1] == "bias":
return ".".join(parts[2:-2]), False, True
raise ValueError(f"{name} is unsupported LoRA weight")
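
With the extra return value, the parser now distinguishes three kinds of checkpoint tensors; for example:

from vllm.lora.utils import parse_fine_tuned_lora_name

prefix = "base_model.model.model.layers.9.mlp.down_proj"
assert parse_fine_tuned_lora_name(f"{prefix}.lora_A.weight") == (
    "model.layers.9.mlp.down_proj", True, False)
assert parse_fine_tuned_lora_name(f"{prefix}.lora_B.weight") == (
    "model.layers.9.mlp.down_proj", False, False)
# New: a bias tensor, i.e. a checkpoint name ending in ".bias".
assert parse_fine_tuned_lora_name(f"{prefix}.lora_B.bias") == (
    "model.layers.9.mlp.down_proj", False, True)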