From 1e96c3341a4e055ae392085fecc7a672295b71c2 Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Thu, 11 Apr 2024 15:18:57 -0700
Subject: [PATCH] Add extra punica sizes to support bigger vocabs (#4015)

---
 csrc/punica/bgmv/bgmv_config.h | 12 +++++-
 csrc/punica/punica_ops.cc      | 14 +++---
 tests/lora/test_layers.py      | 78 +++++++++++++++++++---------------
 tests/lora/test_punica.py      | 49 +++++++++++++++++++--
 vllm/lora/layers.py            |  4 +-
 5 files changed, 109 insertions(+), 48 deletions(-)

diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index 1084a0f2..9b76b98a 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -60,7 +60,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 33024) \
     f(in_T, out_T, W_T, narrow, 36864) \
     f(in_T, out_T, W_T, narrow, 49152) \
-// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
+    f(in_T, out_T, W_T, narrow, 64000) \
+    f(in_T, out_T, W_T, narrow, 64256) \
+    f(in_T, out_T, W_T, narrow, 64512) \
+    f(in_T, out_T, W_T, narrow, 102400) \
+    f(in_T, out_T, W_T, narrow, 102656) \
+    f(in_T, out_T, W_T, narrow, 102912) \
+    f(in_T, out_T, W_T, narrow, 128000) \
+    f(in_T, out_T, W_T, narrow, 128256) \
+    f(in_T, out_T, W_T, narrow, 128512) \
+// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
+// and vllm/tests/lora/test_punica.py
 // Keep this in sync with vllm/config::LoRAConfig
 
 #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc
index 28739be1..7ebfd851 100644
--- a/csrc/punica/punica_ops.cc
+++ b/csrc/punica/punica_ops.cc
@@ -20,8 +20,8 @@ inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
   }
 }
 
-inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
-  return (uint32_t(a) << 16) | uint32_t(b);
+inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
+  return (uint64_t(a) << 32) | uint64_t(b);
 }
 
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
@@ -46,13 +46,13 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 template <typename in_T, typename out_T, typename W_T>
 inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                                const int64_t *lora_indices,
-                               uint16_t in_features, uint16_t out_features,
+                               uint32_t in_features, uint32_t out_features,
                                int64_t y_offset, int64_t full_y_size,
                                int64_t batch_size, int64_t num_layers,
                                int64_t layer_idx, float scale) {
-  switch (pack_u16(in_features, out_features)) {
+  switch (pack_u32(in_features, out_features)) {
 #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
-  case pack_u16(feat_in, feat_out): \
+  case pack_u32(feat_in, feat_out): \
     bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
                                    full_y_size, batch_size, num_layers, \
                                    layer_idx, scale); \
@@ -93,7 +93,7 @@ void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
@@ -325,7 +325,7 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
   CHECK_EQ(y.size(0), x.size(0));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   bool ok = false;
-  if (h_in < 65536 && h_out < 65536) {
+  if (h_in <= 128512 && h_out <= 128512) {
     // TODO: See if we can get rid of this massive nested switch
     switch (x.scalar_type()) {
       case at::ScalarType::Half:
diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py
index 71ce6f17..e9e0c855 100644
--- a/tests/lora/test_layers.py
+++ b/tests/lora/test_layers.py
@@ -170,7 +170,8 @@ def create_random_inputs(
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
     torch.set_default_device(device)
     max_loras = 8
 
@@ -179,9 +180,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding.weight.data = torch.rand_like(embedding.weight.data)
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         lora_embedding = VocabParallelEmbeddingWithLoRA(embedding)
         lora_embedding.create_lora_weights(max_loras, lora_config)
 
@@ -203,12 +204,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
 
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info)
 
     lora_result = lora_embedding(torch.cat(inputs))
@@ -240,12 +242,13 @@ def test_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
 
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
     lora_result = lora_embedding(torch.cat(inputs))
@@ -263,7 +266,9 @@ def test_embeddings(dist_init, num_loras, device) -> None:
 #     reason="Fails when loras are in any slot other than the first.")
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
+                                        vocab_size) -> None:
     torch.set_default_device(device)
     max_loras = 8
 
@@ -272,15 +277,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def create_random_embedding_layer():
-        embedding = VocabParallelEmbedding(512, 256)
+        embedding = VocabParallelEmbedding(vocab_size, 256)
         embedding_data = torch.rand_like(embedding.weight.data)
         embedding.weight.data = embedding_data
-        embedding.weight.data[512:, :] = 0
+        embedding.weight.data[vocab_size:, :] = 0
         expanded_embedding = VocabParallelEmbedding(
-            512 + lora_config.lora_extra_vocab_size * max_loras,
+            vocab_size + lora_config.lora_extra_vocab_size * max_loras,
             256,
-            org_num_embeddings=512)
-        expanded_embedding.weight.data[:512, :] = embedding_data
+            org_num_embeddings=vocab_size)
+        expanded_embedding.weight.data[:vocab_size, :] = embedding_data
         # We need to deepcopy the embedding as it will be modified
         # in place
         lora_embedding = VocabParallelEmbeddingWithLoRA(
@@ -298,7 +303,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
         id_to_index,
         layer=lora_embedding,
         layer_weights=torch.zeros(
-            (256, 512 + lora_config.lora_extra_vocab_size)),
+            (256, vocab_size + lora_config.lora_extra_vocab_size)),
         generate_embeddings_tensor=256,
     )
 
@@ -316,7 +321,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=list(lora_dict.keys()),
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
 
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
@@ -327,16 +332,18 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
     for input_, original_input_, lora_id in zip(inputs, original_inputs,
                                                 prompt_mapping):
         embedding_id = lora_id - 1
-        input_[-1] = 512 + (embedding_id * embeddings_tensor_len)
-        original_input_[-1] = 512
-        input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1)
-        original_input_[-2] = 512 + embeddings_tensor_len - 1
+        input_[-1] = vocab_size + (embedding_id * embeddings_tensor_len)
+        original_input_[-1] = vocab_size
+        input_[-2] = vocab_size + (
+            (embedding_id + 1) * embeddings_tensor_len - 1)
+        original_input_[-2] = vocab_size + embeddings_tensor_len - 1
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
-    expanded_embedding.weight[512:512 +
+    expanded_embedding.weight[vocab_size:vocab_size +
                               (embeddings_tensor_len *
                                max_loras)] = torch.cat(embeddings_tensors)
 
@@ -370,14 +377,15 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
         active_lora_ids=[0],
         num_inputs=num_loras * 3,
         input_size=(200, ),
-        input_range=(1, 512),
+        input_range=(1, vocab_size),
     )
 
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
     original_inputs = deepcopy(inputs)
 
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   512, lora_config.lora_extra_vocab_size)
+                                   vocab_size,
+                                   lora_config.lora_extra_vocab_size)
     lora_embedding.set_mapping(*mapping_info, )
 
     lora_result = lora_embedding(torch.cat(original_inputs))
@@ -393,7 +401,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
 @torch.inference_mode()
 @pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
+@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
+def test_lm_head_logits_processor(dist_init, num_loras, device,
+                                  vocab_size) -> None:
     torch.set_default_device(device)
     max_loras = 8
 
@@ -402,12 +412,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
                              lora_dtype=torch.float16)
 
     def _pretest():
-        linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size,
-                                1024, 32000)
+        linear = ParallelLMHead(vocab_size + lora_config.lora_extra_vocab_size,
+                                1024, vocab_size)
         linear.weight.data = torch.rand_like(linear.weight.data)
-        linear.weight.data[:, 32000:] = 0
+        linear.weight.data[:, vocab_size:] = 0
         logits_processor = LogitsProcessor(
-            32000 + lora_config.lora_extra_vocab_size, 32000)
+            vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
         lora_logits_processor = LogitsProcessorWithLoRA(
             logits_processor, 1024, linear.weight.dtype, linear.weight.device)
         lora_logits_processor.create_lora_weights(max_loras, lora_config)
@@ -444,7 +454,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
         lora_mapping,
         id_to_index,
         max_loras,
-        32000,
+        vocab_size,
         lora_config.lora_extra_vocab_size,
     )
     lora_logits_processor.set_mapping(*mapping_info, )
@@ -460,7 +470,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
                       org_vocab_size:logits_processor.org_vocab_size +
                       embeddings_tensor_len] = embeddings_tensor
 
-    logits_processor.org_vocab_size = (32000 +
+    logits_processor.org_vocab_size = (vocab_size +
                                        lora_config.lora_extra_vocab_size)
     expected_results = []
     for input_, lora_id in zip(inputs, prompt_mapping):
@@ -468,11 +478,11 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
         result = logits_processor._get_logits(hidden_states=input_,
                                               embedding=linear.weight,
                                               embedding_bias=None)
-        result[:, 32000 + embeddings_tensor_len:] = float("-inf")
+        result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
         result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
         expected_results.append(result)
     expected_result = torch.cat(expected_results)
-    logits_processor.org_vocab_size = 32000
+    logits_processor.org_vocab_size = vocab_size
 
     # Check that resetting the lora weights succeeds
 
@@ -489,14 +499,14 @@ def test_lm_head_logits_processor(dist_init, num_loras, device) -> None:
 
     lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
     mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
-                                   32000,
+                                   vocab_size,
                                    lora_config.lora_extra_vocab_size)
     lora_logits_processor.set_mapping(*mapping_info, )
 
     lora_result = lora_logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
-        embedding_bias=None)[:, :32000]
+        embedding_bias=None)[:, :vocab_size]
     expected_result = logits_processor._get_logits(
         hidden_states=torch.cat(inputs),
         embedding=original_weight,
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index 2736a1c7..cab8b44c 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -43,10 +43,51 @@ def _lora_ref_impl(
 
 
 H1 = H2 = [
-    128, 256, 512, 1024, 1152, 1280, 1536, 2048, 2304, 2560, 2752, 3072, 3456,
-    3584, 4096, 4608, 5120, 5504, 5632, 6144, 6848, 6912, 7168, 8192, 9216,
-    10240, 11008, 13824, 14336, 22016, 24576, 27392, 32000, 32256, 32512,
-    32768, 33024
+    128,
+    256,
+    512,
+    1024,
+    1152,
+    1280,
+    1536,
+    2048,
+    2304,
+    2560,
+    2752,
+    3072,
+    3456,
+    3584,
+    4096,
+    4608,
+    5120,
+    5504,
+    5632,
+    6144,
+    6848,
+    6912,
+    7168,
+    8192,
+    9216,
+    10240,
+    11008,
+    13824,
+    14336,
+    22016,
+    24576,
+    27392,
+    32000,
+    32256,
+    32512,
+    32768,
+    33024,
+    36864,
+    49152,
+    64000,
+    64256,
+    102400,
+    102656,
+    128000,
+    128256,
 ]
 SEED = [0xabcdabcd987]
 CUDA_DEVICES = [
diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index a8ec4dcf..5456b561 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -935,9 +935,9 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
-        if 32000 < self.base_layer.vocab_size > 33024:
+        if 32000 < self.base_layer.vocab_size > 128512:
             raise ValueError("When using LoRA, vocab size must be "
-                             "32000 >= vocab_size <= 33024")
+                             "32000 >= vocab_size <= 128512")
         self.lora_a_stacked = torch.zeros(
             (
                 max_loras,
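Note on the punica_ops.cc change above: `launch_bgmv_kernel` picks a template instantiation by switching on a single integer key packed from `(in_features, out_features)`, and the helper had to grow from 16-bit to 32-bit halves because the new widths no longer fit in a `uint16_t`. The sketch below is a standalone Python illustration of that packing, not part of the patch; the `narrow=16` value is just a hypothetical LoRA rank chosen for the example.

```python
UINT16_MAX = 0xFFFF


def pack_u32(a: int, b: int) -> int:
    """Python mirror of the widened C++ pack_u32 helper: combine two
    32-bit values (in_features, out_features) into one 64-bit dispatch key."""
    assert 0 <= a < 2**32 and 0 <= b < 2**32
    return (a << 32) | b


# The largest width compiled in above, 128512, no longer fits in a uint16_t,
# so the old pack_u16-based switch key could not represent the new shapes.
assert 128512 > UINT16_MAX

# Example dispatch key for a hypothetical (narrow=16, wide=128256) BGMV shape.
print(hex(pack_u32(16, 128256)))  # 0x100001f500
```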
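The vllm/lora/layers.py hunk raises the LoRA vocab-size ceiling from 33024 to 128512. As a minimal sketch (not vLLM's actual API), assuming the intended rule is the inclusive range the error message describes: note that the chained comparison in the patched line, `32000 < vocab_size > 128512`, only fires when `vocab_size > 128512`, whereas an explicit range test expresses the bound directly. The helper name below is illustrative.

```python
# Sketch of the vocab-size bound implied by the patch: 32000..128512, where
# 128512 is the largest width compiled into csrc/punica/bgmv/bgmv_config.h.
MIN_LORA_VOCAB = 32000
MAX_LORA_VOCAB = 128512


def check_lora_vocab_size(vocab_size: int) -> None:
    # Reject anything outside the inclusive range supported by the kernels.
    if not (MIN_LORA_VOCAB <= vocab_size <= MAX_LORA_VOCAB):
        raise ValueError(
            "When using LoRA, vocab size must satisfy "
            f"{MIN_LORA_VOCAB} <= vocab_size <= {MAX_LORA_VOCAB}, "
            f"got {vocab_size}")


check_lora_vocab_size(128256)  # e.g. a Llama 3 sized vocab: accepted
```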