[Bugfix] remove post_layernorm in siglip (#8106)
This commit is contained in:
parent
ccd7207191
commit
d3311562fb
@ -443,14 +443,27 @@ class SiglipVisionTransformer(nn.Module):
|
||||
self.config = config
|
||||
embed_dim = config.hidden_size
|
||||
|
||||
if (num_hidden_layers_override is None
|
||||
or num_hidden_layers_override == config.num_hidden_layers):
|
||||
self.need_post_layernorm = True
|
||||
elif num_hidden_layers_override > config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
"num_hidden_layers_override cannot be greater than "
|
||||
"num_hidden_layers")
|
||||
else:
|
||||
self.need_post_layernorm = False
|
||||
|
||||
self.embeddings = SiglipVisionEmbeddings(config)
|
||||
self.encoder = SiglipEncoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
if self.need_post_layernorm:
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
else:
|
||||
self.post_layernorm = nn.Identity()
|
||||
self.use_head = (True if not hasattr(config, "vision_use_head") else
|
||||
config.vision_use_head)
|
||||
if self.use_head:
|
||||
@ -470,7 +483,6 @@ class SiglipVisionTransformer(nn.Module):
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||
|
||||
# TODO: add this back when pooled_output is used in inference
|
||||
# if self.use_head:
|
||||
# pooled_output = self.head(last_hidden_state)
|
||||
@ -499,6 +511,10 @@ class SiglipVisionModel(nn.Module):
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
)
|
||||
|
||||
@property
|
||||
def need_post_layernorm(self):
|
||||
return self.vision_model.need_post_layernorm
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
@ -517,6 +533,11 @@ class SiglipVisionModel(nn.Module):
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is optional in SiglipVisionModel
|
||||
if ("vision_model.post_layernorm" in name
|
||||
and not self.need_post_layernorm):
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
if "vision_model.encoder.layers." in name:
|
||||
layer_idx = int(name.split(".")[3])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user