[misc] fix custom allreduce p2p cache file generation (#7853)

This commit is contained in:
youkaichao 2024-08-26 15:02:25 -07:00 committed by GitHub
parent dd9857f5fa
commit 05826c887b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import os
import pickle
import subprocess
import sys
import tempfile
from itertools import product
from typing import Dict, List, Optional, Sequence
@ -211,20 +212,27 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# However, `can_actually_p2p` requires the `spawn` method.
# The fix is to use `subprocess` to run this file as a script,
# so the function is called from the `if __name__ == "__main__":` block in this file.
input_bytes = pickle.dumps((batch_src, batch_tgt))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
result = pickle.loads(returned.stdout)
# Use a temporary file to pass the result back from the subprocess.
# We don't parse the captured output of the subprocess directly,
# because the subprocess might also produce logging output that would get mixed into it.
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps(
(batch_src, batch_tgt, output_file.name))
returned = subprocess.run([sys.executable, __file__],
input=input_bytes,
capture_output=True)
# check if the subprocess is successful
try:
returned.check_returncode()
except Exception as e:
# wrap raised exception to provide more information
raise RuntimeError(
f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e
with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r
with open(path, "w") as f:
@ -241,6 +249,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
__all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__":
batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read())
batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt)
sys.stdout.buffer.write(pickle.dumps(result))
with open(output_file, "wb") as f:
f.write(pickle.dumps(result))