[BugFix] Fix Embedding Models with TP>1 (#5075)

This commit is contained in:
Robert Shaw 2024-05-28 08:32:42 -07:00 committed by GitHub
parent d4f3985907
commit 9ba415588a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -79,6 +79,10 @@ class EmbeddingModelRunner(ModelRunner):
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return None
return self.model.pooler(hidden_states=hidden_states,
pooling_metadata=pooling_metadata)