[misc] fix custom allreduce p2p cache file generation (#7853)

This commit is contained in:
youkaichao 2024-08-26 15:02:25 -07:00 committed by GitHub
parent dd9857f5fa
commit 05826c887b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,6 +4,7 @@ import os
import pickle import pickle
import subprocess import subprocess
import sys import sys
import tempfile
from itertools import product from itertools import product
from typing import Dict, List, Optional, Sequence from typing import Dict, List, Optional, Sequence
@ -211,7 +212,13 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
# However, `can_actually_p2p` requires spawn method. # However, `can_actually_p2p` requires spawn method.
# The fix is, we use `subprocess` to call the function, # The fix is, we use `subprocess` to call the function,
# where we have `if __name__ == "__main__":` in this file. # where we have `if __name__ == "__main__":` in this file.
input_bytes = pickle.dumps((batch_src, batch_tgt))
# use a temporary file to store the result
# we don't use the output of the subprocess directly,
# because the subprocess might produce logging output
with tempfile.NamedTemporaryFile() as output_file:
input_bytes = pickle.dumps(
(batch_src, batch_tgt, output_file.name))
returned = subprocess.run([sys.executable, __file__], returned = subprocess.run([sys.executable, __file__],
input=input_bytes, input=input_bytes,
capture_output=True) capture_output=True)
@ -224,7 +231,8 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
f"Error happened when batch testing " f"Error happened when batch testing "
f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" f"peer-to-peer access from {batch_src} to {batch_tgt}:\n"
f"{returned.stderr.decode()}") from e f"{returned.stderr.decode()}") from e
result = pickle.loads(returned.stdout) with open(output_file.name, "rb") as f:
result = pickle.load(f)
for _i, _j, r in zip(batch_src, batch_tgt, result): for _i, _j, r in zip(batch_src, batch_tgt, result):
cache[f"{_i}->{_j}"] = r cache[f"{_i}->{_j}"] = r
with open(path, "w") as f: with open(path, "w") as f:
@ -241,6 +249,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
__all__ = ["gpu_p2p_access_check"] __all__ = ["gpu_p2p_access_check"]
if __name__ == "__main__": if __name__ == "__main__":
batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read()) batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read())
result = can_actually_p2p(batch_src, batch_tgt) result = can_actually_p2p(batch_src, batch_tgt)
sys.stdout.buffer.write(pickle.dumps(result)) with open(output_file, "wb") as f:
f.write(pickle.dumps(result))