[BugFix] Fix Embedding Models with TP>1 (#5075)
This commit is contained in:
parent
d4f3985907
commit
9ba415588a
@ -79,6 +79,10 @@ class EmbeddingModelRunner(ModelRunner):
|
|||||||
execute_model_kwargs.update({"image_input": multi_modal_input})
|
execute_model_kwargs.update({"image_input": multi_modal_input})
|
||||||
hidden_states = model_executable(**execute_model_kwargs)
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
|
# Only perform pooling in the driver worker.
|
||||||
|
if not self.is_driver_worker:
|
||||||
|
return None
|
||||||
|
|
||||||
return self.model.pooler(hidden_states=hidden_states,
|
return self.model.pooler(hidden_states=hidden_states,
|
||||||
pooling_metadata=pooling_metadata)
|
pooling_metadata=pooling_metadata)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user