[Model][LoRA] LoRA support added for LlamaEmbeddingModel (#10071)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li · 2024-11-06 17:49:19 +08:00 · committed by GitHub
parent 6a585a23d2
commit 2003cc3513
2 changed files with 20 additions and 2 deletions


@@ -333,7 +333,7 @@ Text Embedding
   * - :code:`MistralModel`
     - Mistral-based
     - :code:`intfloat/e5-mistral-7b-instruct`, etc.
-    -
+    - ✅︎
     - ✅︎
 
 .. important::
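
The row above now marks LoRA as supported for Mistral-based embedding models. A minimal usage sketch of the Python API, assuming a locally stored adapter — the adapter name, integer id, path, and the rank value are placeholders, not part of this commit:

from vllm import LLM
from vllm.lora.request import LoRARequest

# enable_lora switches on the code path this commit adds for the
# embedding model; max_lora_rank must be >= the adapter's LoRA rank.
llm = LLM(
    model="intfloat/e5-mistral-7b-instruct",
    enable_lora=True,
    max_lora_rank=16,
)

# encode() is the embedding entry point; the LoRARequest
# (name, integer id, path) applies the adapter to this batch only.
outputs = llm.encode(
    ["query: how do I enable LoRA in vLLM?"],
    lora_request=LoRARequest("e5_adapter", 1, "/path/to/adapter"),
)
print(len(outputs[0].outputs.embedding))  # embedding dimension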


@@ -627,7 +627,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         return name, loaded_weight
 
 
-class LlamaEmbeddingModel(nn.Module, SupportsPP):
+class LlamaEmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
     """
     A model that uses Llama with additional embedding functionalities.
@@ -638,6 +638,19 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
         model: An instance of LlamaModel used for forward operations.
         _pooler: An instance of Pooler used for pooling operations.
     """
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+    }
+    embedding_padding_modules = []
 
     def __init__(
         self,
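
The packed_modules_mapping above exists because PEFT checkpoints store separate LoRA matrices for q_proj, k_proj, and v_proj, while vLLM executes a single fused qkv_proj (and likewise gate_up_proj); the mapping tells the LoRA loader which checkpoint tensors belong to each fused layer. A minimal sketch of the underlying idea — not vLLM's internal code, and with the LoRA alpha/rank scaling factor omitted for brevity:

import torch

def packed_qkv_lora_delta(x, loras):
    # x: [tokens, hidden]; loras: [(A, B), ...] for q_proj, k_proj, v_proj,
    # with A: [rank, hidden] and B: [out_i, rank].
    # Each per-module delta is B @ A @ x; concatenating along the output
    # dimension matches the fused qkv_proj output layout.
    deltas = [(x @ A.T) @ B.T for A, B in loras]
    return torch.cat(deltas, dim=-1)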
@@ -679,3 +692,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
 
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
+
+    # LRUCacheWorkerLoRAManager instantiation requires model config.
+    @property
+    def config(self):
+        return self.model.config
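
The forwarding property above is needed because the LoRA worker manager reads the HF config off the model it is given, and LlamaEmbeddingModel keeps its LlamaModel in self.model rather than exposing config itself. The delegation pattern in isolation, with illustrative class names:

class Inner:
    def __init__(self, config):
        self.config = config

class Wrapper:
    def __init__(self, inner):
        self.model = inner

    @property
    def config(self):  # same forwarding as LlamaEmbeddingModel.config
        return self.model.config

w = Wrapper(Inner(config={"vocab_size": 32000}))
assert w.config == {"vocab_size": 32000}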