[misc] fix custom allreduce p2p cache file generation (#7853)

This commit is contained in:
youkaichao 2024-08-26 15:02:25 -07:00 committed by GitHub
parent dd9857f5fa
commit 05826c887b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -4,6 +4,7 @@ import os
import pickle import pickle
import subprocess import subprocess
import sys import sys
import tempfile
from itertools import product from itertools import product
from typing import Dict, List, Optional, Sequence from typing import Dict, List, Optional, Sequence
@@ -211,20 +212,27 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# However, `can_actually_p2p` requires spawn method. # However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function, # The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file. # where we have `if __name__ == "__main__":` in this file.
input_bytes = pickle.dumps((batch_src, batch_tgt))
returned = subprocess.run([sys.executable, __file__], # use a temporary file to store the result
input=input_bytes, # we don't use the output of the subprocess directly,
capture_output=True) # because the subprocess might produce logging output
# check if the subprocess is successful with tempfile.NamedTemporaryFile() as output_file:
try: input_bytes = pickle.dumps(
returned.check_returncode() (batch_src, batch_tgt, output_file.name))
except Exception as e: returned = subprocess.run([sys.executable, __file__],
# wrap raised exception to provide more information input=input_bytes,
raise RuntimeError( capture_output=True)
f"Error happened when batch testing " # check if the subprocess is successful
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" try:
f"{returned.stderr.decode()}") from e returned.check_returncode()
result = pickle.loads(returned.stdout) except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result): for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r cache[f"{_i}->{_j}"] = r
with open(path, "w") as f: with open(path, "w") as f:
@@ -241,6 +249,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
__all__ = ["gpu_p2p_access_check"] __all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__": if __name__ == "__main__":
batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read()) batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt) result = can_actually_p2p(batch_src, batch_tgt)
sys.stdout.buffer.write(pickle.dumps(result)) with open(output_file, "wb") as f:
f.write(pickle.dumps(result))