diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index fcedf0fe..e922beba 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -250,15 +250,13 @@ class NCCLCommunicator:
         assert isinstance(device, torch.device)
         self.device = device
         # nccl communicator and stream will use this device
-        current_device = torch.cuda.current_device()
-        try:
-            torch.cuda.set_device(device)
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
             NCCL_CHECK(
                 _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank))
             self.stream = torch.cuda.Stream()
-        finally:
-            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
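
Reviewer note (not part of the diff): a minimal standalone sketch of the behavior this change relies on. `torch.cuda.device` records the current device on entry and restores it on exit, which matches the removed try/finally save-and-restore, and it also restores the device if the body raises. `init_on_device` below is a hypothetical stand-in for the NCCL init call; it is not a vLLM or PyTorch API.

    import torch

    def init_on_device(device: torch.device) -> torch.cuda.Stream:
        # hypothetical stand-in for _c_ncclCommInitRank + stream creation;
        # inside the context manager, the current device is `device`
        assert torch.cuda.current_device() == device.index
        return torch.cuda.Stream()

    if torch.cuda.is_available():
        before = torch.cuda.current_device()
        target = torch.device("cuda", 0)
        # equivalent to: save current device, set_device(target),
        # run body, then restore the saved device in a finally block
        with torch.cuda.device(target):
            stream = init_on_device(target)
        assert torch.cuda.current_device() == before  # device restored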