diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
index c642e949..86846c27 100644
--- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
index 0607cebf..de39c312 100644
--- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index fec484d6..19c058ca 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
 // and vllm/tests/lora/test_punica.py
 
+// Used for defining kernels going from the variety of
+// dims in to the narrow dim out.
+// Used for the fully sharded column parallel LoRA A,
+// which splits the rank dim.
+#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
+  f(in_T, out_T, W_T, 128, narrow) \
+  f(in_T, out_T, W_T, 256, narrow) \
+  f(in_T, out_T, W_T, 512, narrow) \
+  f(in_T, out_T, W_T, 640, narrow) \
+  f(in_T, out_T, W_T, 768, narrow) \
+  f(in_T, out_T, W_T, 1024, narrow) \
+  f(in_T, out_T, W_T, 1152, narrow) \
+  f(in_T, out_T, W_T, 1280, narrow) \
+  f(in_T, out_T, W_T, 1536, narrow) \
+  f(in_T, out_T, W_T, 1728, narrow) \
+  f(in_T, out_T, W_T, 1792, narrow) \
+  f(in_T, out_T, W_T, 2048, narrow) \
+  f(in_T, out_T, W_T, 2304, narrow) \
+  f(in_T, out_T, W_T, 2560, narrow) \
+  f(in_T, out_T, W_T, 2752, narrow) \
+  f(in_T, out_T, W_T, 2816, narrow) \
+  f(in_T, out_T, W_T, 3072, narrow) \
+  f(in_T, out_T, W_T, 3456, narrow) \
+  f(in_T, out_T, W_T, 3584, narrow) \
+  f(in_T, out_T, W_T, 4096, narrow) \
+  f(in_T, out_T, W_T, 4608, narrow) \
+  f(in_T, out_T, W_T, 5120, narrow) \
+  f(in_T, out_T, W_T, 5504, narrow) \
+  f(in_T, out_T, W_T, 5632, narrow) \
+  f(in_T, out_T, W_T, 6144, narrow) \
+  f(in_T, out_T, W_T, 6848, narrow) \
+  f(in_T, out_T, W_T, 6912, narrow) \
+  f(in_T, out_T, W_T, 7168, narrow) \
+  f(in_T, out_T, W_T, 8192, narrow) \
+  f(in_T, out_T, W_T, 9216, narrow) \
+  f(in_T, out_T, W_T, 10240, narrow) \
+  f(in_T, out_T, W_T, 11008, narrow) \
+  f(in_T, out_T, W_T, 12288, narrow) \
+  f(in_T, out_T, W_T, 13696, narrow) \
+  f(in_T, out_T, W_T, 13824, narrow) \
+  f(in_T, out_T, W_T, 14336, narrow) \
+  f(in_T, out_T, W_T, 15360, narrow) \
+  f(in_T, out_T, W_T, 16384, narrow) \
+  f(in_T, out_T, W_T, 20480, narrow) \
+  f(in_T, out_T, W_T, 22016, narrow) \
+  f(in_T, out_T, W_T, 24576, narrow) \
+  f(in_T, out_T, W_T, 27392, narrow) \
+  f(in_T, out_T, W_T, 28672, narrow) \
+  f(in_T, out_T, W_T, 32000, narrow) \
+  f(in_T, out_T, W_T, 32256, narrow) \
+  f(in_T, out_T, W_T, 32512, narrow) \
+  f(in_T, out_T, W_T, 32768, narrow) \
+  f(in_T, out_T, W_T, 33024, narrow) \
+  f(in_T, out_T, W_T, 36864, narrow) \
+  f(in_T, out_T, W_T, 43264, narrow) \
+  f(in_T, out_T, W_T, 49152, narrow) \
+  f(in_T, out_T, W_T, 64000, narrow) \
+  f(in_T, out_T, W_T, 64256, narrow) \
+  f(in_T, out_T, W_T, 64512, narrow) \
+  f(in_T, out_T, W_T, 102400, narrow) \
+  f(in_T, out_T, W_T, 102656, narrow) \
+  f(in_T, out_T, W_T, 102912, narrow) \
+  f(in_T, out_T, W_T, 128000, narrow) \
+  f(in_T, out_T, W_T, 128256, narrow) \
+  f(in_T, out_T, W_T, 128512, narrow)
+// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
+
 // Keep this in sync with vllm/config::LoRAConfig
 #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
   FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
@@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
   FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
 
+
+#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
+  FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
+  FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
+  FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
+  f(in_T, out_T, W_T, 8, 64) \
+  f(in_T, out_T, W_T, 16, 64) \
+  f(in_T, out_T, W_T, 32, 64) \
+  f(in_T, out_T, W_T, 64, 64)
+
 // clang-format on
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
index f1db6df5..d225a1ea 100644
--- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
+++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
index c01ddd00..b37d288a 100644
--- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
+++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
index f45183ff..a1ab2dee 100644
--- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
index 40977434..0b35bf56 100644
--- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
+++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh
index 995de26e..dad8805c 100644
--- a/csrc/punica/bgmv/bgmv_impl.cuh
+++ b/csrc/punica/bgmv/bgmv_impl.cuh
@@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   constexpr int tz = 4;
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  if constexpr (feat_in < feat_out) {
+  if constexpr (feat_in <= feat_out) {
     static_assert(feat_in % vec_size == 0);
     constexpr int tx = feat_in / vec_size;
 
@@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
       int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
       int64_t num_layers, int64_t layer_idx, float scale);
 
+#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
+  INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
+
 #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
   INST_BGMV(narrow, wide, in_T, out_T, W_T) \
   INST_BGMV(wide, narrow, in_T, out_T, W_T)
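The ONESIDE instantiations above add shrink kernels whose output width can be smaller than the narrowest two-sided width, which is exactly what a rank shard of LoRA A produces. A rough sketch of the call they make possible, assuming a CUDA device, the compiled `vllm._punica_C` extension, and purely illustrative shapes:

    import torch

    from vllm.lora import punica

    # LoRA A shrink: hidden size 4096 in, sharded rank 4 out
    # (e.g. max_lora_rank=16 split across tp_size=4).
    num_loras, num_layers, h1, r, bs = 4, 1, 4096, 4, 32
    wa_T_all = torch.randn(num_loras, num_layers, r, h1,
                           dtype=torch.float16, device="cuda")
    x = torch.randn(bs, h1, dtype=torch.float16, device="cuda")
    y = torch.zeros(bs, r, dtype=torch.float16, device="cuda")
    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device="cuda")

    punica.bgmv(y, x, wa_T_all, indices, 0, 1.0)  # layer_idx=0, scale=1.0
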
diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py
index 9bf7f635..972df5a7 100644
--- a/csrc/punica/bgmv/generator.py
+++ b/csrc/punica/bgmv/generator.py
@@ -10,6 +10,7 @@ TEMPLATE = """
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
+FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
 """.lstrip()  # noqa: E501
 
 for input_dtype in DTYPES:
diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index a1eaa90e..8797fde8 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
     CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
 
     FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
+    FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
 #undef CASE
 #undef CASE_ONESIDE
 
     default:
       return false;
   }
-
   return true;
 }
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 1616fdfd..0eb04f4c 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -8,6 +8,10 @@ import torch
 import torch.nn.functional as F
 
 from vllm.config import LoRAConfig
+from vllm.lora.fully_sharded_layers import (
+    ColumnParallelLinearWithShardedLoRA,
+    MergedColumnParallelLinearWithShardedLoRA,
+    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
@@ -524,13 +528,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("orientation", ["row", "column"])
+@pytest.mark.parametrize("fully_shard", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
+def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
+                         device) -> None:
     torch.set_default_device(device)
     max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
+                             fully_sharded_loras=fully_shard,
                              lora_dtype=torch.float16)
 
     def create_random_linear_parallel_layer():
@@ -540,14 +547,17 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
                                        bias=False,
                                        params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = RowParallelLinearWithLoRA(linear)
+            lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard
+                           else RowParallelLinearWithShardedLoRA(linear))
         else:
             linear = ColumnParallelLinear(4096,
                                           4096,
                                           bias=False,
                                           params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = ColumnParallelLinearWithLoRA(linear)
+            lora_linear = (ColumnParallelLinearWithLoRA(linear)
+                           if not fully_shard else
+                           ColumnParallelLinearWithShardedLoRA(linear))
         lora_linear.create_lora_weights(max_loras, lora_config)
 
         return linear, lora_linear
@@ -629,13 +639,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("repeats", [1, 2, 3])
+@pytest.mark.parametrize("fully_shard", [True, False])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
+def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
+                                device) -> None:
     torch.set_default_device(device)
     max_loras = 8
     lora_config = LoRAConfig(max_loras=max_loras,
                              max_lora_rank=8,
+                             fully_sharded_loras=fully_shard,
                              lora_dtype=torch.float16)
 
     def create_column_parallel_packed_layer():
@@ -644,7 +657,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                                bias=False,
                                                params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = MergedColumnParallelLinearWithLoRA(linear)
+            lora_linear = (MergedColumnParallelLinearWithLoRA(linear)
+                           if not fully_shard else
+                           MergedColumnParallelLinearWithShardedLoRA(linear))
         elif repeats == 3:
             linear = QKVParallelLinear(4096,
                                        64,
@@ -652,7 +667,9 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
                                        bias=False,
                                        params_dtype=torch.float16)
             linear.weight.data = torch.rand_like(linear.weight.data)
-            lora_linear = MergedQKVParallelLinearWithLora(linear)
+            lora_linear = (MergedQKVParallelLinearWithLora(linear)
+                           if not fully_shard else
+                           MergedQKVParallelLinearWithShardedLora(linear))
         else:
             linear = QKVParallelLinear(4096,
                                        64,
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index f3b9bd59..fd2a1b75 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -34,11 +34,14 @@ def _lora_ref_impl(
     for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
         xi = x[i].unsqueeze(0).to(torch.float32)
         wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
-        wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
+        if wb_T_all is not None:
+            wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
+                                                         -2).to(torch.float32)
 
         tmp = xi @ wa
         y_stage_1[i] = tmp.squeeze(0)
-        y_final[i] += (tmp @ wb).squeeze(0) * s
+        y_final[i] += ((tmp @ wb).squeeze(0) *
+                       s if wb_T_all is not None else y_stage_1[i])
     return y_final, y_stage_1
 
 
@@ -91,12 +94,56 @@ H1 = H2 = [
     128000,
     128256,
 ]
+H2 = [64] + H2
+R = [1, 2, 4]
 SEED = [0xabcdabcd987]
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ]
 
 
+@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
+@pytest.mark.parametrize("h1", H1)
+@pytest.mark.parametrize("r", R)
+@pytest.mark.parametrize("seed", SEED)
+@torch.inference_mode()
+def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
+    torch.manual_seed(seed)
+    num_loras = 4
+    num_layers = 1
+    bs = 32
+    dtype = getattr(torch, dtype_str)
+    device = torch.device("cuda")
+
+    wa_T_all = torch.randn(num_loras,
+                           num_layers,
+                           r,
+                           h1,
+                           dtype=dtype,
+                           device=device)
+    indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
+
+    for layer_idx in range(num_layers):
+        x = torch.randn(bs, h1, dtype=dtype, device=device)
+        y = torch.randn(bs, r, dtype=dtype, device=device)
+
+        y_ref = y.clone()
+        _lora_ref_impl(
+            y_ref,
+            x,
+            wa_T_all,
+            None,
+            indices,
+            layer_idx,
+            1.0,
+        )
+
+        y_our = y.clone()
+        punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)
+
+        assert_close(y_ref, y_our)
+
+
 @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
 @pytest.mark.parametrize("h1", H1)
 @pytest.mark.parametrize("h2", H2)
diff --git a/vllm/config.py b/vllm/config.py
index 887a73d9..aedb5892 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -862,6 +862,7 @@ class SpeculativeConfig:
 class LoRAConfig:
     max_lora_rank: int
     max_loras: int
+    fully_sharded_loras: bool = False
     max_cpu_loras: Optional[int] = None
     lora_dtype: Optional[torch.dtype] = None
     lora_extra_vocab_size: int = 256
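A minimal sketch of the new config surface, with arbitrary illustrative values; the `--fully-sharded-loras` CLI flag added below maps onto this dataclass field:

    from vllm.config import LoRAConfig

    # fully_sharded_loras defaults to False; when True, from_layer() in
    # vllm/lora/utils.py picks the variants from fully_sharded_layers.py.
    lora_config = LoRAConfig(max_lora_rank=16,
                             max_loras=4,
                             fully_sharded_loras=True)
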
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 6a6ac49a..bd6437ee 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -52,6 +52,7 @@ class EngineArgs:
     enable_lora: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
+    fully_sharded_loras: bool = False
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
@@ -376,6 +377,14 @@ class EngineArgs:
             help=('Maximum number of LoRAs to store in CPU memory. '
                   'Must be >= than max_num_seqs. '
                   'Defaults to max_num_seqs.'))
+        parser.add_argument(
+            '--fully-sharded-loras',
+            action='store_true',
+            help=('By default, only half of the LoRA computation is '
+                  'sharded with tensor parallelism. '
+                  'Enabling this will use the fully sharded layers. '
+                  'At high sequence length, max rank, or '
+                  'tensor parallel size, this is likely faster.'))
         parser.add_argument("--device",
                             type=str,
                             default=EngineArgs.device,
@@ -509,6 +518,7 @@ class EngineArgs:
             lora_config = LoRAConfig(
                 max_lora_rank=self.max_lora_rank,
                 max_loras=self.max_loras,
+                fully_sharded_loras=self.fully_sharded_loras,
                 lora_extra_vocab_size=self.lora_extra_vocab_size,
                 lora_dtype=self.lora_dtype,
                 max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py
new file mode 100644
index 00000000..17205668
--- /dev/null
+++ b/vllm/lora/fully_sharded_layers.py
@@ -0,0 +1,262 @@
+# pylint: disable=unused-argument
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.config import LoRAConfig
+from vllm.distributed.communication_op import (
+    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
+from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
+from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
+                              MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLora,
+                              RowParallelLinearWithLoRA)
+from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
+
+if TYPE_CHECKING:
+    pass
+
+
+def _fully_sharded_can_replace(can_replace):
+    """
+    Decorator which adds the condition of fully sharded LoRAs.
+    Intended to wrap can_replace_layer().
+    """
+
+    def dec(*args, **kwargs):
+        return (can_replace(*args, **kwargs)
+                and kwargs['lora_config'].fully_sharded_loras)
+
+    return dec
+
+
+# These layers are based on the tensor parallelism strategy given in
+# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
+# https://arxiv.org/abs/2311.03285.
+
+
+class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
+    """
+    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
+
+    Based on S-LoRA, slicing happens along the rank dim.
+ """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + bgmv(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + # now have column partitioned output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +def _mcp_apply_weights(x, bias, layer): + """ + MergedColumnParallelLinearWithShardedLoRA and + QKVParallelLinearWithShardedLora share the same + LoRa weight application method. + + The main difference is the step by shard_size for lora_b which can + vary for QKVParallelLinearWithShardedLora but is constant for + MergedColumnParallelLinearWithShardedLoRA. + """ + # expecting 2 for column parallel and 3 for qkv + n = len(layer.lora_a_stacked) + output = layer.base_layer.linear_method.apply_weights( + layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device) + for idx in range(n): + bgmv(buffers[idx], x, layer.lora_a_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0) + + buffers = tensor_model_parallel_all_gather(buffers) + left_offset = 0 + for idx in range(n): + shard_size = layer.lora_b_stacked[idx].shape[2] + dispatch_bgmv_low_level(output, buffers[idx], + layer.lora_b_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0, + left_offset, shard_size) + left_offset += shard_size + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. 
+ """ + + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[i][:, output_start_idx:output_start_idx + output_shard_size] + for i in range(2) + ] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply_weights(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[i][:, start_idx[i]:start_idx[i] + + shard_size[i]] if lora_a[i] is not None else None + for i in range(3) + ] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply_weights(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. 
+        # Users should be aware that the output returned here is not the
+        # same as a normal row-parallel output; it still needs to be
+        # reduced before use.
+        shard_size = self.lora_b_stacked.shape[2]
+        start_idx = self.tp_rank * shard_size
+        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
+                                self.indices[:self.indices_len[0]], 0, 1.0,
+                                start_idx, shard_size)
+
+        output = output.view(*out_orig_shape)
+        return output
+
+    @classmethod
+    @_fully_sharded_can_replace
+    def can_replace_layer(cls, source_layer: nn.Module,
+                          lora_config: LoRAConfig, packed_modules_list: List,
+                          model_config: Optional[PretrainedConfig]) -> bool:
+        # specifying kwargs so they can be easily accessed in decorator
+        return super().can_replace_layer(
+            source_layer=source_layer,
+            lora_config=lora_config,
+            packed_modules_list=packed_modules_list,
+            model_config=model_config,
+            decorate=False,
+        )
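A single-process sketch of the column-parallel data flow implemented above, with plain matmuls standing in for `bgmv` and concatenation standing in for the all_gather (shapes are illustrative, not taken from the PR):

    import torch

    hidden, rank, tp_size, bs = 4096, 16, 2, 8
    r_shard = rank // tp_size                    # rank dim split across TP ranks
    out_shard = hidden // tp_size                # column-parallel output shard

    x = torch.randn(bs, hidden)
    lora_a_shard = torch.randn(hidden, r_shard)  # this rank's slice of LoRA A
    lora_b_shard = torch.randn(rank, out_shard)  # LoRA B consumes the full rank

    buffer = x @ lora_a_shard                    # bgmv(buffer, x, lora_a_stacked, ...)
    # tensor_model_parallel_all_gather(buffer) concatenates every rank's
    # r_shard columns; duplicating one rank's buffer stands in for it here.
    gathered = torch.cat([buffer] * tp_size, dim=-1)
    out_cols = gathered @ lora_b_shard           # this rank's columns of the output
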
+ """ def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() @@ -331,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -357,6 +390,17 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + def set_lora( self, index: int, @@ -365,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) @@ -426,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): return output, output_bias @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -451,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config n_slices = 2 if not (len(self.base_layer.output_sizes) == n_slices and self.base_layer.output_sizes[0] @@ -459,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): "LoRAColumnParallelLinear2Slice requires 2 slices with " "the same size.") self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = tuple( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -489,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + return lora_a + + def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + shard_size = self.output_dim + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank 
+        lora_b = [
+            lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
+        ]
+        return lora_b
+
     def set_lora(
         self,
         index: int,
@@ -499,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.reset_lora(index)
 
         if self.tp_size > 1:
-            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-            shard_size = self.output_dim
-            start_idx = tensor_model_parallel_rank * shard_size
-            end_idx = (tensor_model_parallel_rank + 1) * shard_size
-            lora_b = lora_b[0][:,
-                               start_idx:end_idx], lora_b[1][:,
-                                                             start_idx:end_idx]
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
 
         if lora_a[0] is not None:
             self.lora_a_stacked[0][
@@ -536,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         return output
 
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
@@ -627,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                            max_loras: int,
                            lora_config: LoRAConfig,
                            model_config: Optional[PretrainedConfig] = None) -> None:
+        self.lora_config = lora_config
         self.tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.q_proj_shard_size = (self.base_layer.num_heads *
                                   self.base_layer.head_size)
         self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
                                    self.base_layer.head_size)
-        self.q_shard_id = tp_rank
-        self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
+        self.q_shard_id = self.tp_rank
+        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
 
+        lora_a_output_size_per_partition = (
+            lora_config.max_lora_rank if not lora_config.fully_sharded_loras
+            else divide(lora_config.max_lora_rank, self.tp_size))
         # q, k, v
         self.lora_a_stacked = (
             torch.zeros(
                 max_loras,
                 1,
-                lora_config.max_lora_rank,
+                lora_a_output_size_per_partition,
                 self.input_size,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
@@ -649,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
             torch.zeros(
                 max_loras,
                 1,
-                lora_config.max_lora_rank,
+                lora_a_output_size_per_partition,
                 self.input_size,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
@@ -657,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
             torch.zeros(
                 max_loras,
                 1,
-                lora_config.max_lora_rank,
+                lora_a_output_size_per_partition,
                 self.input_size,
                 dtype=lora_config.lora_dtype,
                 device=self.device,
@@ -705,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.lora_a_stacked[2][index] = 0
         self.lora_b_stacked[2][index] = 0
 
+    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
+        return lora_a
+
+    def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]:
+        if lora_b[0] is not None:
+            lora_b_q = lora_b[0][:, self.q_proj_shard_size *
+                                 self.q_shard_id:self.q_proj_shard_size *
+                                 (self.q_shard_id + 1)]
+        if lora_b[1] is not None:
+            lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
+                                 self.kv_shard_id:self.kv_proj_shard_size *
+                                 (self.kv_shard_id + 1)]
+        if lora_b[2] is not None:
+            lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
+                                 self.kv_shard_id:self.kv_proj_shard_size *
+                                 (self.kv_shard_id + 1)]
+        lora_b = [lora_b_q, lora_b_k, lora_b_v]
+        return lora_b
+
     def set_lora(
         self,
         index: int,
@@ -715,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         self.reset_lora(index)
 
         if self.tp_size > 1:
-            if lora_b[0] is not None:
-                lora_b_q = lora_b[0][:, self.q_proj_shard_size *
-                                     self.q_shard_id:self.q_proj_shard_size *
-                                     (self.q_shard_id + 1)]
-                self.lora_b_stacked[0][
-                    index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
-                        lora_b_q.T, non_blocking=True)
-            if lora_b[1] is not None:
-                lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
-                                     self.kv_shard_id:self.kv_proj_shard_size *
-                                     (self.kv_shard_id + 1)]
-                self.lora_b_stacked[1][
-                    index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
-                        lora_b_k.T, non_blocking=True)
-            if lora_b[2] is not None:
-                lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
-                                     self.kv_shard_id:self.kv_proj_shard_size *
-                                     (self.kv_shard_id + 1)]
-                self.lora_b_stacked[2][
-                    index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
-                        lora_b_v.T, non_blocking=True)
-        else:
-            if lora_b[0] is not None:
-                self.lora_b_stacked[0][
-                    index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
-                        lora_b[0].T, non_blocking=True)
-            if lora_b[1] is not None:
-                self.lora_b_stacked[1][
-                    index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
-                        lora_b[1].T, non_blocking=True)
-            if lora_b[2] is not None:
-                self.lora_b_stacked[2][
-                    index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
-                        lora_b[2].T, non_blocking=True)
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
+
+        if lora_b[0] is not None:
+            lora_b_q = lora_b[0]
+            self.lora_b_stacked[0][
+                index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
+                    lora_b_q.T, non_blocking=True)
+        if lora_b[1] is not None:
+            lora_b_k = lora_b[1]
+            self.lora_b_stacked[1][
+                index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
+                    lora_b_k.T, non_blocking=True)
+        if lora_b[2] is not None:
+            lora_b_v = lora_b[2]
+            self.lora_b_stacked[2][
+                index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
+                    lora_b_v.T, non_blocking=True)
 
         if lora_a[0] is not None:
             self.lora_a_stacked[0][
@@ -777,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
         return output
 
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
@@ -798,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
                            max_loras: int,
                            lora_config: LoRAConfig,
                            model_config: Optional[PretrainedConfig] = None) -> None:
+        self.lora_config = lora_config
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -808,11 +876,16 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+        tp_size = get_tensor_model_parallel_world_size()
+        lora_b_output_size_per_partition = (
+            self.output_size if not lora_config.fully_sharded_loras else
+            divide(self.output_size, tp_size))
+
         self.lora_b_stacked = torch.zeros(
             (
                 max_loras,
                 1,
-                self.output_size,
+                lora_b_output_size_per_partition,
                 lora_config.max_lora_rank,
             ),
             dtype=lora_config.lora_dtype,
@@ -826,6 +899,17 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.lora_a_stacked[index] = 0
         self.lora_b_stacked[index] = 0
 
+    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        shard_size = self.input_size
+        start_idx = tensor_model_parallel_rank * shard_size
+        end_idx = (tensor_model_parallel_rank + 1) * shard_size
+        lora_a = lora_a[start_idx:end_idx, :]
+        return lora_a
+
+    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
+        return lora_b
+
     def set_lora(
         self,
         index: int,
@@ -834,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         embeddings_tensor: Optional[torch.Tensor],
     ):
         self.reset_lora(index)
+
         if self.base_layer.tp_size > 1:
-            tensor_model_parallel_rank = get_tensor_model_parallel_rank()
-            shard_size = self.input_size
-            start_idx = tensor_model_parallel_rank * shard_size
-            end_idx = (tensor_model_parallel_rank + 1) * shard_size
-            lora_a = lora_a[start_idx:end_idx, :]
+            lora_a = self.slice_lora_a(lora_a)
+            lora_b = self.slice_lora_b(lora_b)
 
         self.lora_a_stacked[index,
                             0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
@@ -915,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             self.base_layer, "weight") else self.base_layer.qweight
 
     @classmethod
+    @_not_fully_sharded_can_replace
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
@@ -1096,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
                           model_config: Optional[PretrainedConfig]) -> bool:
         # Special handling for the LogitsProcessor.
         return False
-
-
-_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
-    cls
-    for cls in globals().values() if inspect.isclass(cls)
-    and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA
-}
-
-
-def from_layer(layer: nn.Module,
-               max_loras: int,
-               lora_config: LoRAConfig,
-               packed_modules_list: List,
-               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
-    for lora_cls in _all_lora_classes:
-        if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list,
-                                      model_config):
-            ret = lora_cls(layer)
-            ret.create_lora_weights(max_loras, lora_config, model_config)
-            return ret
-    return layer
-
-
-def from_layer_logits_processor(
-    layer: LogitsProcessor,
-    lm_head: ParallelLMHead,
-    max_loras: int,
-    lora_config: LoRAConfig,
-    model_config: Optional[PretrainedConfig] = None,
-) -> LogitsProcessorWithLoRA:
-    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
-                                  lm_head.weight.dtype, lm_head.weight.device)
-    ret.create_lora_weights(max_loras, lora_config, model_config)
-    return ret
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 6a077e9b..50d7e913 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -11,10 +11,10 @@ from torch import nn
 
 from vllm.config import LoRAConfig
 from vllm.logger import init_logger
-from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer,
-                              from_layer_logits_processor)
+from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
+from vllm.lora.utils import (from_layer, from_layer_logits_processor,
+                             parse_fine_tuned_lora_name, replace_submodule)
 from vllm.utils import LRUCache, is_pin_memory_available
 
 logger = init_logger(__name__)
diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py
index fc74269e..c87bed54 100644
--- a/vllm/lora/punica.py
+++ b/vllm/lora/punica.py
@@ -49,6 +49,49 @@ def bgmv(
     punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
 
 
+def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
+                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
+                            layer_idx: int, scale: float, y_offset: int,
+                            y_slice_size: int):
+    """
+    Same as `bgmv` but allows operating on slices of y.
+    Pass the whole y and define y_offset and y_slice_size.
+
+    Semantics:
+      y[i][y_offset:y_offset + y_slice_size] += (
+          x[i].unsqueeze(0)
+          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
+        all of the transposed LoRA matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      y_offset: Offset to apply to the starting column of y.
+      y_slice_size: Size of the y column slice.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+    punica_kernels.dispatch_bgmv_low_level(
+        y,
+        x,
+        w_t_all,
+        indicies,
+        layer_idx,
+        scale,
+        x.size(1),
+        y_slice_size,
+        y_offset,
+    )
+
+
 def add_lora(y: torch.Tensor,
              x: torch.Tensor,
              wa_t_all: torch.Tensor,
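For reference, the slice semantics documented above can be written out kernel-free in plain PyTorch; `dispatch_bgmv_low_level_ref` below is a hypothetical helper for checking results, not part of the PR:

    import torch


    def dispatch_bgmv_low_level_ref(y: torch.Tensor, x: torch.Tensor,
                                    w_t_all: torch.Tensor,
                                    indicies: torch.LongTensor, layer_idx: int,
                                    scale: float, y_offset: int,
                                    y_slice_size: int) -> None:
        # Accumulate each sample's LoRA product into a column slice of y.
        for i in range(x.size(0)):
            w = w_t_all[indicies[i], layer_idx].transpose(-1, -2).float()
            y[i, y_offset:y_offset + y_slice_size] += (
                (x[i].float().unsqueeze(0) @ w).squeeze(0) * scale).to(y.dtype)
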
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 39e08f04..9942a5fd 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -1,11 +1,69 @@
-from typing import Tuple
+from typing import List, Optional, Set, Tuple, Type
 
 from torch import nn
+from transformers import PretrainedConfig
 
+from vllm.config import LoRAConfig
 from vllm.logger import init_logger
+from vllm.lora.fully_sharded_layers import (
+    ColumnParallelLinearWithShardedLoRA,
+    MergedColumnParallelLinearWithShardedLoRA,
+    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
+# being imported for _all_lora_classes below
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
+                              LogitsProcessorWithLoRA,
+                              MergedColumnParallelLinearWithLoRA,
+                              MergedQKVParallelLinearWithLora,
+                              QKVParallelLinearWithLora,
+                              RowParallelLinearWithLoRA,
+                              VocabParallelEmbeddingWithLoRA)
+# yapf: enable
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 
 logger = init_logger(__name__)
 
+_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
+    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
+    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
+    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
+    LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
+    MergedColumnParallelLinearWithShardedLoRA,
+    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
+}
+
+
+def from_layer(layer: nn.Module,
+               max_loras: int,
+               lora_config: LoRAConfig,
+               packed_modules_list: List,
+               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
+    for lora_cls in _all_lora_classes:
+        # specifying kwargs so they can be easily accessed in decorator
+        if lora_cls.can_replace_layer(source_layer=layer,
+                                      lora_config=lora_config,
+                                      packed_modules_list=packed_modules_list,
+                                      model_config=model_config):
+            ret = lora_cls(layer)
+            ret.create_lora_weights(max_loras, lora_config, model_config)
+            return ret
+    return layer
+
+
+def from_layer_logits_processor(
+    layer: LogitsProcessor,
+    lm_head: ParallelLMHead,
+    max_loras: int,
+    lora_config: LoRAConfig,
+    model_config: Optional[PretrainedConfig] = None,
+) -> LogitsProcessorWithLoRA:
+    ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
+                                  lm_head.weight.dtype, lm_head.weight.device)
+    ret.create_lora_weights(max_loras, lora_config, model_config)
+    return ret
+
 
 def replace_submodule(model: nn.Module, module_name: str,
                       new_module: nn.Module) -> nn.Module: