[Core][Distributed] use existing torch.cuda.device context manager (#4318)
commit 3cd9b5bb2d (parent 468d761b32)
@@ -250,15 +250,13 @@ class NCCLCommunicator:
         assert isinstance(device, torch.device)
         self.device = device
         # nccl communicator and stream will use this device
-        current_device = torch.cuda.current_device()
-        try:
-            torch.cuda.set_device(device)
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
             NCCL_CHECK(
                 _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank))
             self.stream = torch.cuda.Stream()
-        finally:
-            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
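The change is behavior-preserving because `torch.cuda.device` saves the current device on entry and restores it on exit, which is exactly what the removed try/finally did by hand. A minimal sketch of that guarantee, not part of this commit (assumes a CUDA-enabled PyTorch build; device index 0 is illustrative):

import torch

if torch.cuda.is_available():
    before = torch.cuda.current_device()
    # inside the block, the current device is the one passed in
    with torch.cuda.device(torch.device("cuda", 0)):
        assert torch.cuda.current_device() == 0
    # on exit, the previous current device is restored automatically,
    # mirroring the old torch.cuda.set_device(current_device) in finally
    assert torch.cuda.current_device() == before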