[Gemma2] add bitsandbytes support for Gemma2 (#8338)

This commit is contained in:
Blueyo0 2024-09-12 12:53:12 +08:00 committed by GitHub
parent 5a60699c45
commit 1bf2dd9df0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer.
embedding_modules = {}
embedding_padding_modules = []
# Mapping from per-shard checkpoint weight names to their fused ("stacked")
# module name and the shard's index within that fused weight.
# NOTE(review): presumably consumed by vLLM's bitsandbytes weight loader so
# that quantized q/k/v and gate/up shards can be placed into the fused
# qkv_proj and gate_up_proj layers — confirm against the loader that reads
# this attribute; the loader is not visible in this view.
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,