[BugFix] Fix Embedding Models with TP>1 (#5075)
This commit is contained in:
parent
d4f3985907
commit
9ba415588a
@ -79,6 +79,10 @@ class EmbeddingModelRunner(ModelRunner):
|
||||
execute_model_kwargs.update({"image_input": multi_modal_input})
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
|
||||
# Only perform pooling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return None
|
||||
|
||||
return self.model.pooler(hidden_states=hidden_states,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user