[Model] LoRA gptbigcode implementation (#3949)

raywanb, 2024-05-23 04:58:59 +08:00, committed by GitHub
parent a3a73ab069
commit 97b030005c
GPG signature: no known key found for this signature in database (GPG Key ID: B5690EEEBB952194)
4 changed files with 34 additions and 5 deletions

View File

@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, narrow, 2752) \
 f(in_T, out_T, W_T, narrow, 2816) \
 f(in_T, out_T, W_T, narrow, 3072) \
+f(in_T, out_T, W_T, narrow, 3328) \
 f(in_T, out_T, W_T, narrow, 3456) \
 f(in_T, out_T, W_T, narrow, 3584) \
 f(in_T, out_T, W_T, narrow, 4096) \
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, narrow, 5504) \
 f(in_T, out_T, W_T, narrow, 5632) \
 f(in_T, out_T, W_T, narrow, 6144) \
+f(in_T, out_T, W_T, narrow, 6400) \
 f(in_T, out_T, W_T, narrow, 6848) \
 f(in_T, out_T, W_T, narrow, 6912) \
 f(in_T, out_T, W_T, narrow, 7168) \
@@ -97,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, 2752, narrow) \
 f(in_T, out_T, W_T, 2816, narrow) \
 f(in_T, out_T, W_T, 3072, narrow) \
+f(in_T, out_T, W_T, 3328, narrow) \
 f(in_T, out_T, W_T, 3456, narrow) \
 f(in_T, out_T, W_T, 3584, narrow) \
 f(in_T, out_T, W_T, 4096, narrow) \
@@ -105,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
 f(in_T, out_T, W_T, 5504, narrow) \
 f(in_T, out_T, W_T, 5632, narrow) \
 f(in_T, out_T, W_T, 6144, narrow) \
+f(in_T, out_T, W_T, 6400, narrow) \
 f(in_T, out_T, W_T, 6848, narrow) \
 f(in_T, out_T, W_T, 6912, narrow) \
 f(in_T, out_T, W_T, 7168, narrow) \

View File

@@ -58,6 +58,7 @@ H1 = H2 = [
     2560,
     2752,
     3072,
+    3328,
     3456,
     3584,
     4096,
@@ -66,6 +67,7 @@ H1 = H2 = [
     5504,
     5632,
     6144,
+    6400,
     6848,
     6912,
     7168,
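
The H1/H2 list above mirrors the sizes instantiated by the kernel macro in the first file, which is why 3328 and 6400 are added in both places. Below is a minimal reference sketch of the BGMV semantics those sizes feed into (a per-token gathered matrix-vector product); ref_bgmv, the rank, and the batch shapes are illustrative assumptions, not the project's test code.

import torch

def ref_bgmv(y, x, weights, indices, scale=1.0):
    # Per-token gathered matrix-vector product:
    # y[i] += scale * weights[indices[i]] @ x[i]
    for i in range(x.shape[0]):
        y[i] += scale * (weights[indices[i]] @ x[i])
    return y

for hidden in (3328, 6400):          # the sizes this commit adds
    rank = 16                        # illustrative LoRA rank
    x = torch.randn(4, hidden)       # 4 tokens of hidden-dim activations
    w = torch.randn(2, rank, hidden) # "shrink" weights for 2 adapters
    idx = torch.tensor([0, 1, 0, 1]) # adapter index per token
    y = torch.zeros(4, rank)
    ref_bgmv(y, x, w, idx)
    assert y.shape == (4, rank)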

View File

@@ -310,7 +310,9 @@ class LoRAModel:
             if part_name not in expected_lora_modules:
                 unexpected_modules.append(module)
         # loaded lora's target modules must be a subset of expected_lora_modules
         if unexpected_modules:
+            print(unexpected_modules, "modules")
             raise ValueError(
                 f"While loading {lora_dir}, expected"
                 f" target modules in {expected_lora_modules}"

View File

@@ -25,7 +25,7 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.attention import Attention, AttentionMetadata
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -191,14 +191,19 @@ class GPTBigCodeModel(nn.Module):
         config: GPTBigCodeConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.wte = VocabParallelEmbedding(self.vocab_size,
+                                          self.embed_dim,
+                                          org_num_embeddings=config.vocab_size)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
             GPTBigCodeBlock(config, cache_config, quant_config)
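
The constructor change above widens the embedding table so each possible LoRA slot can carry its own extra vocabulary rows, while org_num_embeddings preserves the base size for checkpoint loading. A quick sketch of the arithmetic; the concrete numbers below are assumptions for illustration, not values taken from this commit.

# Mirrors lora_vocab = lora_extra_vocab_size * (max_loras or 1) above.
base_vocab_size = 49152        # e.g. a GPT-BigCode tokenizer
lora_extra_vocab_size = 256    # assumed LoRAConfig setting
max_loras = 4                  # assumed number of concurrent adapters

lora_vocab = lora_extra_vocab_size * (max_loras or 1)  # 1024
vocab_size = base_vocab_size + lora_vocab              # 50176
print(vocab_size)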
@@ -226,19 +231,35 @@ class GPTBigCodeModel(nn.Module):
 class GPTBigCodeForCausalLM(nn.Module):
+    packed_modules_mapping = {"c_attn": ["c_attn"]}
+
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
+    embedding_modules = {
+        "wte": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = []
+
     def __init__(
         self,
         config: GPTBigCodeConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
+        self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
+                                           lora_config)
         self.lm_head_weight = self.transformer.wte.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
         self.sampler = Sampler()

     def forward(
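
Taken together, the class-level attributes (packed_modules_mapping, supported_lora_modules, embedding_modules, embedding_padding_modules), the lora_config plumbing, and the unpadded vocab size passed to LogitsProcessor are what let the LoRA machinery target GPT-BigCode models. A usage sketch follows; it assumes vLLM's standard enable_lora/LoRARequest entry points, and the model name and adapter path are placeholders rather than part of this commit.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Serve a GPT-BigCode model with LoRA enabled (placeholders for illustration).
llm = LLM(model="bigcode/starcoderbase-1b", enable_lora=True, max_loras=2)
outputs = llm.generate(
    ["def fibonacci(n):"],
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("my-adapter", 1, "/path/to/gptbigcode-lora"),
)
print(outputs[0].outputs[0].text)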