[LoRA] Adds support for bias in LoRA (#5733)
Signed-off-by: Umesh Deshpande <udeshpa@us.ibm.com>
Co-authored-by: Umesh Deshpande <udeshpa@us.ibm.com>
commit 8a06428c70 (parent b41fb9d3b1)
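In short: LoRA adapters that ship lora_B bias tensors can now be loaded, and the bias is added to the layer output, gated behind a new enable_lora_bias switch. A hypothetical offline-inference sketch of the feature (the adapter repo id is the one the new test fixture below downloads; the prompt and request name are placeholders):

    import vllm
    from huggingface_hub import snapshot_download
    from vllm.lora.request import LoRARequest

    # LoRA checkpoint that includes bias tensors (same repo as the new fixture).
    lora_path = snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")

    llm = vllm.LLM("ibm-granite/granite-3b-code-base",
                   enable_lora=True,
                   enable_lora_bias=True,  # new switch introduced by this change
                   max_lora_rank=8)

    outputs = llm.generate(
        ["[user] Write a SQL query ... [/user] [assistant]"],  # placeholder prompt
        vllm.SamplingParams(temperature=0, max_tokens=64),
        lora_request=LoRARequest("sql-lora-bias", 1, lora_path))
    print(outputs[0].outputs[0].text)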
@@ -152,6 +152,11 @@ def sql_lora_files(sql_lora_huggingface_id):
     return snapshot_download(repo_id=sql_lora_huggingface_id)
 
 
+@pytest.fixture(scope="session")
+def lora_bias_files():
+    return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
+
+
 @pytest.fixture(scope="session")
 def mixtral_lora_files():
     # Note: this module has incorrect adapter_config.json to test
tests/lora/test_lora_bias_e2e.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from typing import List
+
+import pytest
+
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "ibm-granite/granite-3b-code-base"
+
+
+def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
+    prompts = [
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]",  # noqa: E501
+        "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]"  # noqa: E501
+    ]
+    sampling_params = vllm.SamplingParams(temperature=0,
+                                          max_tokens=256,
+                                          stop=["[/assistant]"])
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    generated_texts: List[str] = []
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        generated_texts.append(generated_text)
+    return generated_texts
+
+
+@pytest.mark.parametrize("lora_bias", [True])
+@pytest.mark.parametrize("fully_sharded", [True, False])
+def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool):
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_lora_rank=8,
+                   max_loras=1,
+                   enable_lora_bias=lora_bias,
+                   tensor_parallel_size=1,
+                   fully_sharded_loras=fully_sharded)
+
+    print("lora adapter created")
+    output1 = do_sample(llm, lora_bias_files, lora_id=0)
+
+    print("lora")
+    output2 = do_sample(llm, lora_bias_files, lora_id=1)
+
+    if lora_bias:
+        assert output1 != output2
+    else:
+        assert output1 == output2
@@ -12,36 +12,40 @@ from vllm.utils import LRUCache
 def test_parse_fine_tuned_lora_name_valid():
     fixture = {
-        ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
-        ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
+        ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
+        ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
         (
             "base_model.model.model.embed_tokens.lora_embedding_A",
             "model.embed_tokens",
             True,
+            False,
         ),
         (
             "base_model.model.model.embed_tokens.lora_embedding_B",
             "model.embed_tokens",
             False,
+            False,
         ),
         (
             "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
             "model.layers.9.mlp.down_proj",
             True,
+            False,
        ),
         (
             "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
             "model.layers.9.mlp.down_proj",
             False,
+            False,
         ),
     }
-    for name, module_name, is_lora_a in fixture:
-        assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
+    for name, module_name, is_lora_a, is_bias in fixture:
+        assert (module_name, is_lora_a,
+                is_bias) == parse_fine_tuned_lora_name(name)
 
 
 def test_parse_fine_tuned_lora_name_invalid():
     fixture = {
         "weight",
         "base_model.weight",
         "base_model.model.weight",
     }
@@ -1687,6 +1687,7 @@ class LoRAConfig:
     # This is a constant.
     lora_vocab_padding_size: ClassVar[int] = 256
     long_lora_scaling_factors: Optional[Tuple[float]] = None
+    bias_enabled: bool = False
 
     def __post_init__(self):
         # Setting the maximum rank to 256 should be able to satisfy the vast
@@ -143,6 +143,7 @@ class EngineArgs:
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     enable_lora: bool = False
+    enable_lora_bias: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
     enable_prompt_adapter: bool = False
@@ -584,6 +585,9 @@ class EngineArgs:
         parser.add_argument('--enable-lora',
                             action='store_true',
                             help='If True, enable handling of LoRA adapters.')
+        parser.add_argument('--enable-lora-bias',
+                            action='store_true',
+                            help='If True, enable bias for LoRA adapters.')
         parser.add_argument('--max-loras',
                             type=int,
                             default=EngineArgs.max_loras,
@@ -1148,6 +1152,7 @@ class EngineArgs:
                                       and parallel_config.use_ray),
             policy=self.scheduling_policy)
         lora_config = LoRAConfig(
+            bias_enabled=self.enable_lora_bias,
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            fully_sharded_loras=self.fully_sharded_loras,
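The new flag travels from the CLI into LoRAConfig.bias_enabled. A minimal sketch of that wiring, assuming only pre-existing vLLM helpers (FlexibleArgumentParser, add_cli_args, from_cli_args) plus the new --enable-lora-bias switch; the model name is a placeholder:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([
        "--model", "ibm-granite/granite-3b-code-base",
        "--enable-lora", "--enable-lora-bias", "--max-lora-rank", "8",
    ])
    engine_args = EngineArgs.from_cli_args(args)
    # Ends up as LoRAConfig(bias_enabled=True) in create_engine_config() above.
    assert engine_args.enable_lora_bias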
@@ -70,6 +70,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
                                        self.lora_b_stacked,
                                        add_input=True)
         # now have column partitioned output
+
+        if self.bias_stacked is not None:
+            self.bias_stacked = self.bias_stacked.view(
+                -1, self.bias_stacked.shape[-1])
+            self.bias_stacked = self.bias_stacked[
+                self.punica_wrapper.token_lora_indices]
+            output += self.bias_stacked
+
         output = output.view(*out_orig_shape)
         return output
@@ -121,6 +129,15 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
     left_offset = 0
     for idx in range(n):
         shard_size = layer.lora_b_stacked[idx].shape[2]
+
+        if layer.bias_stacked is not None:
+            bias = layer.bias_stacked[idx]
+            if bias is not None:
+                bias = bias.view(-1, bias.shape[-1])
+                bias = bias[layer.punica_wrapper.token_lora_indices]
+                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
+                output[:, left_offset:left_offset + shard_size] += bias
+
         layer.punica_wrapper.add_expand_slice(
             output,
             buffers[idx],
@@ -295,6 +312,15 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        if bias is None:
+            return bias
+        shard_size = self.bias_stacked.shape[2]
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
+        bias = bias[start_idx:end_idx]
+        return bias
+
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
 
@@ -318,6 +344,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         # reduced before being used
         shard_size = self.lora_b_stacked.shape[2]
         start_idx = self.tp_rank * shard_size
+
+        if self.bias_stacked is not None:
+            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
+            bias = bias[self.punica_wrapper.token_lora_indices]
+            bias[self.punica_wrapper.token_lora_indices == -1] = 0
+            output += bias
+
         self.punica_wrapper.add_expand_slice(output, buffer,
                                              self.lora_b_stacked, start_idx,
                                              shard_size)
@@ -67,6 +67,63 @@ def _not_fully_sharded_can_replace(can_replace):
     return dec
 
 
+def apply_bias(
+    indices: torch.Tensor,
+    output: torch.Tensor,
+    bias_stacked: torch.Tensor,
+):
+    """Applies bias to output
+
+    Input shapes:
+        bias_stacked:      (num_loras, output_dim)
+        indices:           (batch_size)
+        output:            (batch_size, output_dim)
+    """
+    org_output = output
+    output = output.view(-1, output.shape[-1])
+    indices = indices.view(-1)
+
+    bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
+    bias_stacked = bias_stacked[indices]
+    bias_stacked[indices == -1] = 0
+    output += bias_stacked
+
+    return output.view_as(org_output)
+
+
+def apply_bias_packed_nslice(
+    indices: torch.Tensor,
+    output: torch.Tensor,
+    output_slices: Tuple[int, ...],
+    bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+):
+    """Applies bias to output
+
+    Input shapes:
+        bias_stacked:      3 element tuple of (num_loras, output_dim)
+        indices:           (batch_size)
+        output:            (batch_size, q_slice_size + 2*kv_slice_size)
+        output_slices:     n-1 element tuple of (slice_size...),
+                           where n is number of slices
+    """
+    org_output = output
+    output = output.view(-1, output.shape[-1])
+    indices = indices.view(-1)
+
+    offset_left = 0
+    for slice_idx, slice in enumerate(output_slices):
+        bias = bias_stacked[slice_idx]
+        if bias is not None:
+            bias = bias.view(-1, bias.shape[-1])
+            bias = bias[indices]
+            bias[indices == -1] = 0
+            output[:, offset_left:offset_left + slice] += bias
+
+        offset_left += slice
+
+    return output.view_as(org_output)
+
+
 @dataclass
 class LoRAMapping(AdapterMapping):
     is_prefill: bool = False
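A self-contained sketch of what apply_bias does, with made-up toy shapes: each row of output gathers the bias of the LoRA its token was routed to, and rows whose index is -1 (no adapter) get zero bias.

    import torch

    num_loras, output_dim = 2, 4
    bias_stacked = torch.arange(num_loras * output_dim,
                                dtype=torch.float32).view(num_loras, output_dim)
    indices = torch.tensor([0, -1, 1])   # per-token LoRA index; -1 = no LoRA
    output = torch.zeros(len(indices), output_dim)

    bias = bias_stacked.view(-1, bias_stacked.shape[-1])[indices]
    bias[indices == -1] = 0              # mask out tokens without an adapter
    output += bias
    print(output)  # row 0: bias of LoRA 0, row 1: zeros, row 2: bias of LoRA 1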
@@ -105,6 +162,7 @@ class BaseLayerWithLoRA(nn.Module):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         """Overwrites lora tensors at index."""
         ...
@@ -203,6 +261,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
@@ -299,10 +358,22 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+        if lora_config.bias_enabled:
+            self.bias_stacked = torch.zeros(
+                max_loras,
+                1,
+                self.output_size,
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+        else:
+            self.bias_stacked = None
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[index] = 0
 
     def set_lora(
         self,
@@ -310,6 +381,7 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
@@ -319,10 +391,21 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias(
+                self.indices,
+                output,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                      self.lora_b_stacked, 1.0)
         return output
@@ -401,11 +484,25 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+
+        if lora_config.bias_enabled:
+            self.bias_stacked = torch.zeros(
+                max_loras,
+                1,
+                self.output_size,
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+        else:
+            self.bias_stacked = None
 
         self.output_dim = self.lora_b_stacked.shape[2]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[index] = 0
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         return lora_a
@@ -418,18 +515,30 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_b = lora_b[:, start_idx:end_idx]
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        if bias is None:
+            return bias
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        shard_size = self.output_dim
+        start_idx = tensor_model_parallel_rank * shard_size
+        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        bias = bias[start_idx:end_idx]
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            bias = self.slice_bias(bias)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -437,10 +546,21 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias(
+                self.indices,
+                output,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                      self.lora_b_stacked, 1.0)
         return output
@@ -534,6 +654,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         ) for _ in range(n_slices))
+        if lora_config.bias_enabled:
+            self.bias_stacked = tuple(
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.output_size // 2,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ) for _ in range(n_slices))
+        else:
+            self.bias_stacked = None
 
         self.output_dim = self.lora_b_stacked[0].shape[2]
 
@@ -542,6 +673,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.lora_a_stacked[1][index] = 0
         self.lora_b_stacked[0][index] = 0
         self.lora_b_stacked[1][index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[0][index] = 0
+            self.bias_stacked[1][index] = 0
 
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
@@ -562,18 +696,32 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         ]
         return lora_b
 
+    def slice_bias(
+        self, bias: List[Union[torch.Tensor,
+                               None]]) -> List[Union[torch.Tensor, None]]:
+        if bias[0] is None or bias[1] is None:
+            return bias
+        shard_size = self.output_dim
+        start_idx = self.tp_rank * shard_size
+        end_idx = (self.tp_rank + 1) * shard_size
+        bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         if lora_a[0] is not None:
             self.lora_a_stacked[0][
@@ -582,6 +730,10 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             self.lora_b_stacked[0][
                 index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
                     lora_b[0].T, non_blocking=True)
+            if bias is not None and bias[0] is not None:
+                self.bias_stacked[0][index,
+                                     0, :bias[0].shape[0]].copy_(bias[0].T,
+                                                                 non_blocking=True)
         if lora_a[1] is not None:
             self.lora_a_stacked[1][
                 index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
@@ -589,10 +741,22 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             self.lora_b_stacked[1][
                 index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                     lora_b[1].T, non_blocking=True)
+            if bias is not None and bias[1] is not None:
+                self.bias_stacked[1][index,
+                                     0, :bias[1].shape[0]].copy_(bias[1].T,
+                                                                 non_blocking=True)
 
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias_packed_nslice(
+                self.indices,
+                output,
+                (self.output_dim, self.output_dim),
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora_packed_nslice(
             output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
             (self.output_dim, self.output_dim))
@@ -654,17 +818,35 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1)
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        bias_q = bias[self.q_proj_shard_size *
+                      self.q_shard_id:self.q_proj_shard_size *
+                      (self.q_shard_id + 1)]
+        k_offset = self.q_proj_total_size
+        bias_k = bias[k_offset +
+                      self.kv_proj_shard_size * self.kv_shard_id:k_offset +
+                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        v_offset = k_offset + self.kv_proj_total_size
+        bias_v = bias[v_offset +
+                      self.kv_proj_shard_size * self.kv_shard_id:v_offset +
+                      self.kv_proj_shard_size * (self.kv_shard_id + 1)]
+        bias = torch.cat([bias_q, bias_k, bias_v], dim=1)
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -672,6 +854,10 @@ class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     @classmethod
     @_not_fully_sharded_can_replace
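The q/k/v offset arithmetic in slice_bias above, replayed on toy numbers (sizes are invented for illustration): the flat bias vector is laid out as [q | k | v], and each tensor-parallel rank takes its own q shard plus the k and v shards of its KV group.

    # Invented sizes: two ranks, q of 8, k and v of 4 each.
    q_proj_total_size, kv_proj_total_size = 8, 4
    q_proj_shard_size, kv_proj_shard_size = 4, 2
    q_shard_id, kv_shard_id = 1, 1            # shard ids owned by this rank

    bias = list(range(q_proj_total_size + 2 * kv_proj_total_size))  # [q | k | v]

    bias_q = bias[q_proj_shard_size * q_shard_id:
                  q_proj_shard_size * (q_shard_id + 1)]
    k_offset = q_proj_total_size
    bias_k = bias[k_offset + kv_proj_shard_size * kv_shard_id:
                  k_offset + kv_proj_shard_size * (kv_shard_id + 1)]
    v_offset = k_offset + kv_proj_total_size
    bias_v = bias[v_offset + kv_proj_shard_size * kv_shard_id:
                  v_offset + kv_proj_shard_size * (kv_shard_id + 1)]

    print(bias_q, bias_k, bias_v)   # [4, 5, 6, 7] [10, 11] [14, 15]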
@@ -768,6 +954,32 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                 device=self.device,
             ),
         )
+        if lora_config.bias_enabled:
+            self.bias_stacked = (
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.q_proj_shard_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.kv_proj_shard_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ),
+                torch.zeros(
+                    max_loras,
+                    1,
+                    self.kv_proj_shard_size,
+                    dtype=lora_config.lora_dtype,
+                    device=self.device,
+                ),
+            )
+        else:
+            self.bias_stacked = None
 
         self.output_slices = (
             self.q_proj_shard_size,
@@ -787,6 +999,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.lora_b_stacked[1][index] = 0
         self.lora_a_stacked[2][index] = 0
         self.lora_b_stacked[2][index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[0][index] = 0
+            self.bias_stacked[1][index] = 0
+            self.bias_stacked[2][index] = 0
 
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
@@ -812,18 +1028,40 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         lora_b = [lora_b_q, lora_b_k, lora_b_v]
         return lora_b
 
+    def slice_bias(
+        self, bias: List[Union[torch.Tensor,
+                               None]]) -> List[Union[torch.Tensor, None]]:
+        bias_q, bias_k, bias_v = bias
+        if bias_q is not None:
+            bias_q = bias_q[self.q_proj_shard_size *
+                            self.q_shard_id:self.q_proj_shard_size *
+                            (self.q_shard_id + 1)]
+        if bias_k is not None:
+            bias_k = bias_k[self.kv_proj_shard_size *
+                            self.kv_shard_id:self.kv_proj_shard_size *
+                            (self.kv_shard_id + 1)]
+        if bias_v is not None:
+            bias_v = bias_v[self.kv_proj_shard_size *
+                            self.kv_shard_id:self.kv_proj_shard_size *
+                            (self.kv_shard_id + 1)]
+        bias = [bias_q, bias_k, bias_v]
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         if lora_b[0] is not None:
             lora_b_q = lora_b[0]
@@ -854,9 +1092,28 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                 index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                     lora_a[2].T, non_blocking=True)
 
+        if bias is not None:
+            if bias[0] is not None:
+                self.bias_stacked[0][index, 0, :bias[0].shape[0]].copy_(
+                    bias[0].T, non_blocking=True)
+            if bias[1] is not None:
+                self.bias_stacked[1][index, 0, :bias[1].shape[0]].copy_(
+                    bias[1].T, non_blocking=True)
+            if bias[2] is not None:
+                self.bias_stacked[2][index, 0, :bias[2].shape[0]].copy_(
+                    bias[2].T, non_blocking=True)
+
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias_packed_nslice(
+                self.indices,
+                output,
+                self.output_slices,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                    self.lora_a_stacked,
                                                    self.lora_b_stacked, 1.0,
@@ -919,9 +1176,27 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             device=self.device,
         )
+
+        if lora_config.bias_enabled:
+            self.bias_stacked = torch.zeros(
+                (
+                    max_loras,
+                    1,
+                    self.output_size,
+                ),
+                dtype=lora_config.lora_dtype,
+                device=self.device,
+            )
+        else:
+            self.bias_stacked = None
         # Lazily initialized
         self.indices: torch.Tensor
         self.indices_len: List[int]
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
+        if self.lora_config.bias_enabled:
+            self.bias_stacked[index] = 0
 
     def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
@@ -934,18 +1209,24 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
         return lora_b
 
+    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
+        return bias
+
     def set_lora(
         self,
         index: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
 
         if self.base_layer.tp_size > 1:
             lora_a = self.slice_lora_a(lora_a)
             lora_b = self.slice_lora_b(lora_b)
+            if bias is not None:
+                bias = self.slice_bias(bias)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -953,9 +1234,20 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.lora_b_stacked[index,
                             0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
                                 lora_b.T, non_blocking=True)
+        if bias is not None:
+            self.bias_stacked[index,
+                              0, :bias.shape[0]].copy_(bias.T,
+                                                       non_blocking=True)
 
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
+        if self.bias_stacked is not None:
+            self.indices = self.punica_wrapper.token_lora_indices
+            output = apply_bias(
+                self.indices,
+                output,
+                self.bias_stacked,
+            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                      self.lora_b_stacked, 1.0)
         return output
@@ -1132,6 +1424,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         self.reset_lora(index)
         self.lora_a_stacked[index,
@@ -1199,7 +1492,7 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
                                       neginf=float("-inf")))
         logits[:,
                self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
-               lora_logits.shape[1], ] = lora_logits
+               lora_logits.shape[1]] = lora_logits
 
         # LogitsProcessorWithLoRA always using bgmv
         self.punica_wrapper.add_lora_logits(logits, hidden_states,
@@ -1276,6 +1569,7 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
         embeddings_tensor: Optional[torch.Tensor],
+        bias: Optional[torch.Tensor] = None,
     ):
         ...
 
@@ -17,6 +17,7 @@ class LoRALayerWeights:
         lora_alpha: int,
         lora_a: torch.Tensor,
         lora_b: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
         embeddings_tensor: Optional[torch.Tensor] = None,
         scaling: Optional[float] = None,
     ) -> None:
@@ -25,6 +26,7 @@ class LoRALayerWeights:
         self.lora_alpha = lora_alpha
         self.lora_a = lora_a
         self.lora_b = lora_b
+        self.bias = bias
         self.embeddings_tensor = embeddings_tensor
 
         if scaling is None:
@@ -66,7 +68,8 @@ class LoRALayerWeights:
             rank: int,
             dtype: torch.dtype,
             device: torch.types.Device,
-            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
+            embeddings_tensor_dim: Optional[int] = None,
+            bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros([input_dim, rank],
                              dtype=dtype,
@@ -76,6 +79,14 @@ class LoRALayerWeights:
                              dtype=dtype,
                              device=device,
                              pin_memory=pin_memory)
+        if bias_enabled:
+            bias = torch.zeros([output_dim],
+                               dtype=dtype,
+                               device=device,
+                               pin_memory=pin_memory)
+        else:
+            bias = None
+
         embeddings_tensor = torch.rand(
             10,
             embeddings_tensor_dim,
@@ -88,6 +99,7 @@ class LoRALayerWeights:
             lora_alpha=1,
             lora_a=lora_a,
             lora_b=lora_b,
+            bias=bias,
            embeddings_tensor=embeddings_tensor,
         )
 
@@ -102,6 +114,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
         lora_alphas: List[Optional[int]],
         lora_a: List[Optional[torch.Tensor]],
         lora_b: List[Optional[torch.Tensor]],
+        bias: Optional[List[Optional[torch.Tensor]]] = None,
         scaling: Optional[List[float]] = None,
     ) -> None:
         super().__init__(
@@ -110,6 +123,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             lora_alpha=0,
             lora_a=lora_a,
             lora_b=lora_b,
+            bias=bias,
             scaling=scaling,  # type: ignore
             embeddings_tensor=None,
         )
@@ -141,6 +155,7 @@ class PackedLoRALayerWeights(LoRALayerWeights):
             [lora.lora_alpha if lora is not None else None for lora in loras],
             [lora.lora_a if lora is not None else None for lora in loras],
             [lora.lora_b if lora is not None else None for lora in loras],
+            [lora.bias if lora is not None else None for lora in loras],
             scaling=[
                 1 if lora is not None else None  # type: ignore
                 for lora in loras
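With bias_enabled=True, the dummy/warmup weights now also carry a zero bias of length output_dim. A sketch of the call (argument names other than those visible in this hunk are assumed from the surrounding signature):

    import torch
    from vllm.lora.lora import LoRALayerWeights

    lora = LoRALayerWeights.create_dummy_lora_weights(
        module_name="model.layers.0.mlp.down_proj",   # arbitrary module name
        input_dim=64,
        output_dim=32,
        rank=8,
        dtype=torch.float16,
        device="cpu",
        bias_enabled=True)
    assert lora.bias is not None and lora.bias.shape == (32,)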
@@ -4,7 +4,7 @@ import math
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type
 
 import safetensors.torch
 import torch
@@ -119,7 +119,8 @@ class LoRAModel(AdapterModel):
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         loras: Dict[str, LoRALayerWeights] = {}
         for tensor_name, tensor in tensors.items():
-            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
+            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
+                tensor_name)
             if module_name not in loras:
                 lora_embeddings_tensor = None
                 if embeddings:
@@ -136,8 +137,16 @@ class LoRAModel(AdapterModel):
                         lora_embeddings_tensor.pin_memory())
                 loras[module_name] = LoRALayerWeights(module_name, rank,
                                                       lora_alpha, None, None,
+                                                      None,
                                                       lora_embeddings_tensor)
-            if is_lora_a:
+            if is_bias:
+                loras[module_name].bias = tensor.to(device=device,
+                                                    dtype=dtype).t()
+                bias = tensor.to(device=device, dtype=dtype).t()
+                if pin_memory:
+                    bias = bias.pin_memory()
+                loras[module_name].bias = bias
+            elif is_lora_a:
                 loras[module_name].lora_a = tensor.to(device=device,
                                                       dtype=dtype).t()
                 if pin_memory:
@@ -215,7 +224,7 @@ class LoRAModel(AdapterModel):
         with safetensors.safe_open(lora_tensor_path,
                                    framework="pt") as f:  # type: ignore
             for lora_module in f.keys():  # noqa
-                module_name, _ = parse_fine_tuned_lora_name(lora_module)
+                module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
                 part_name = module_name.split(".")[-1]
                 if part_name not in expected_lora_modules:
                     unexpected_modules.append(module_name)
@@ -386,8 +395,19 @@ class LoRAModelManager(AdapterModelManager):
             module_lora = lora_model.get_lora(module_name)
             if module_lora:
                 module_lora.optimize()
+                # Bias is not explicitly enabled with the flag enable_lora_bias.
+                bias = module_lora.bias
+                if ((torch.is_tensor(bias) or
+                     (isinstance(bias, Sequence) and any(b is not None
+                                                         for b in bias)))
+                        and not self.lora_config.bias_enabled):
+                    module_lora.bias = None
+                    raise ValueError(
+                        f"Adapter bias cannot be used for {module_name}"
+                        " without --enable-lora-bias.")
                 module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
-                                module_lora.embeddings_tensor)
+                                module_lora.embeddings_tensor,
+                                module_lora.bias)
             else:
                 module.reset_lora(index)
         return True
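Loading an adapter that ships bias tensors without --enable-lora-bias now fails fast. The guard counts both a single tensor (plain layers) and a tuple or list containing any tensor (packed q/k/v layers) as "has bias"; a small illustration of that predicate, pulled out on its own:

    import torch
    from typing import Sequence

    def adapter_has_bias(bias) -> bool:
        # Same condition as the check above, shown in isolation.
        return (torch.is_tensor(bias)
                or (isinstance(bias, Sequence) and any(b is not None
                                                       for b in bias)))

    print(adapter_has_bias(None))                          # False
    print(adapter_has_bias(torch.zeros(8)))                # True
    print(adapter_has_bias((None, torch.zeros(8), None)))  # True (merged q/k/v)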
@@ -509,6 +529,7 @@ class LoRAModelManager(AdapterModelManager):
         """Create zero-initialized LoRAModel for warmup."""
         model = LoRAModel(lora_id, rank, {}, scaling_factor)
         for module_name, module in self.model.named_modules():
+            bias_enabled = self.lora_config.bias_enabled
             if (not self._match_target_modules(module_name)
                     or not isinstance(module, BaseLayerWithLoRA)
                     or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
@@ -536,7 +557,8 @@ class LoRAModelManager(AdapterModelManager):
                     rank,
                     module.lora_a_stacked.dtype,
                     "cpu",
-                    embeddings_tensor_dim=embeddings_tensor_dim)
+                    embeddings_tensor_dim=embeddings_tensor_dim,
+                    bias_enabled=bias_enabled)
             else:
                 lora = LoRALayerWeights.create_dummy_lora_weights(
                     module_name,
@@ -545,6 +567,7 @@ class LoRAModelManager(AdapterModelManager):
                     rank,
                     module.lora_a_stacked.dtype,
                     "cpu",
+                    bias_enabled=bias_enabled,
                 )
                 lora.optimize()
             else:
@@ -559,6 +582,7 @@ class LoRAModelManager(AdapterModelManager):
                         rank,
                         module.lora_a_stacked[i].dtype,
                         "cpu",
+                        bias_enabled=bias_enabled,
                     )
                     lora.optimize()
                     subloras.append(lora)
@@ -91,7 +91,7 @@ def replace_submodule(model: nn.Module, module_name: str,
     return new_module
 
 
-def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
+def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
     """Parse the name of lora weights.
 
     args:
@@ -101,15 +101,18 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         Tuple(module_name, is_lora_a):
             module_name: the name of the module, e.g. model.dense1,
             is_lora_a whether the tensor is lora_a or lora_b.
+            is_bias whether the tensor is lora bias.
     """
     parts = name.split(".")
+    if parts[-1] == "weight" and (parts[-2] == "lora_A"
+                                  or parts[-2] == "lora_B"):
+        return ".".join(parts[2:-2]), parts[-2] == "lora_A", False
 
-    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
-        if parts[-1] == "weight":
-            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
-                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
-        elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
+    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
+        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A", False
+
+    if parts[-1] == "bias":
+        return ".".join(parts[2:-2]), False, True
 
     raise ValueError(f"{name} is unsupported LoRA weight")
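Concretely, the parser now returns a (module_name, is_lora_a, is_bias) triple. The first two calls below use names from the updated test fixture; the last one is an illustrative bias tensor name:

    from vllm.lora.utils import parse_fine_tuned_lora_name

    print(parse_fine_tuned_lora_name(
        "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight"))
    # ('model.layers.9.mlp.down_proj', True, False)
    print(parse_fine_tuned_lora_name(
        "base_model.model.model.embed_tokens.lora_embedding_B"))
    # ('model.embed_tokens', False, False)
    print(parse_fine_tuned_lora_name(
        "base_model.model.model.layers.9.mlp.down_proj.lora_B.bias"))
    # ('model.layers.9.mlp.down_proj', False, True)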