[Bugfix] Add synchronize to prevent possible data race (#6788)

Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2024-07-25 13:40:01 -04:00 committed by GitHub
parent 65b1f121c8
commit 95db75de64
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -243,6 +243,13 @@ class GroupCoordinator:
ca_comm = self.ca_comm
maybe_ca_context = nullcontext(
) if ca_comm is None else ca_comm.capture()
# ensure all initialization operations complete before attempting to
# capture the graph on another stream
curr_stream = torch.cuda.current_stream()
if curr_stream != stream:
stream.wait_stream(curr_stream)
with torch.cuda.stream(stream), maybe_ca_context:
# In graph mode, we have to be very careful about the collective
# operations. The current status is: