diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index b6160403..c460e2e0 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -84,6 +84,8 @@ class Worker:
             torch.cuda.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
+            torch.cuda.empty_cache()
+            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
@@ -126,7 +128,9 @@ class Worker:
         # profiled peak memory.
         torch.cuda.synchronize()
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        peak_memory = total_gpu_memory - free_gpu_memory
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, cache_dtype, self.model_config, self.parallel_config)
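
Below is a minimal, self-contained sketch of why the accounting changes; the numbers are made up for illustration and this is not vLLM code. The old `total_gpu_memory - free_gpu_memory` formula charges memory held by unrelated processes to this worker's peak, while `self.init_gpu_memory - free_gpu_memory` only counts what the worker itself consumed since its post-`empty_cache()` snapshot, assuming other processes' usage stays constant during profiling.

```python
# Hypothetical scenario: a 24 GiB GPU where another process already holds
# 4 GiB before the worker starts, and profiling consumes 10 GiB on top.
GiB = 1024**3
total_gpu_memory = 24 * GiB
other_process_usage = 4 * GiB   # held by an unrelated process (assumed constant)
profiling_usage = 10 * GiB      # consumed by this worker during profiling

# Free-memory snapshot taken right after init, once the cache is emptied:
init_gpu_memory = total_gpu_memory - other_process_usage  # 20 GiB free

# Free memory observed after profiling:
free_gpu_memory = total_gpu_memory - other_process_usage - profiling_usage  # 10 GiB

# Old formula: attributes the other process's 4 GiB to this worker's peak.
old_peak = total_gpu_memory - free_gpu_memory   # 14 GiB

# New formula: measures only what this worker consumed since its own snapshot,
# assuming the other process's usage did not change in the meantime.
new_peak = init_gpu_memory - free_gpu_memory    # 10 GiB

print(f"old peak: {old_peak / GiB:.0f} GiB, new peak: {new_peak / GiB:.0f} GiB")
```

With the old formula, the KV-cache sizing step would see an inflated peak and allocate fewer cache blocks than the GPU can actually hold for this worker; the new formula avoids that, at the cost of the assumption stated in the NOTE comment.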