From 3cd9b5bb2d4a0d5eed07186ae140f5dc8f839708 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 24 Apr 2024 09:00:20 -0700
Subject: [PATCH] [Core][Distributed] use existing torch.cuda.device (#4318)

[Core][Distributed] use existing torch.cuda.device context manager (#4318)
---
 vllm/distributed/device_communicators/pynccl.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index fcedf0fe..e922beba 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -250,15 +250,13 @@ class NCCLCommunicator:
         assert isinstance(device, torch.device)
         self.device = device
         # nccl communicator and stream will use this device
-        current_device = torch.cuda.current_device()
-        try:
-            torch.cuda.set_device(device)
+        # `torch.cuda.device` is a context manager that changes the
+        # current cuda device to the specified one
+        with torch.cuda.device(device):
             NCCL_CHECK(
                 _c_ncclCommInitRank(ctypes.byref(self.comm), self.world_size,
                                     self.unique_id, self.rank))
             self.stream = torch.cuda.Stream()
-        finally:
-            torch.cuda.set_device(current_device)
 
     def all_reduce(self,
                    tensor: torch.Tensor,
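
For context, a minimal sketch of the behavior this patch relies on:
`torch.cuda.device` records the current CUDA device on entry and restores
it on exit, even if the body raises, which is exactly what the removed
try/finally bookkeeping did by hand. The sketch below is not part of the
patch and assumes a CUDA-enabled PyTorch build with at least two GPUs.

import torch

# Assumption: at least two visible CUDA devices.
if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    torch.cuda.set_device(0)
    with torch.cuda.device(1):
        # Inside the block, device 1 is the current device, so new
        # allocations with device="cuda" land on cuda:1.
        assert torch.cuda.current_device() == 1
        t = torch.empty(4, device="cuda")
    # On exit, the previous device (0) is restored automatically,
    # even if an exception had been raised inside the block.
    assert torch.cuda.current_device() == 0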