[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory usage (#9352)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
parent
48138a8415
commit
de4008e2ab
@ -26,10 +26,12 @@ def test_lazy_outlines(sample_regex):
|
||||
# make sure outlines is not imported
|
||||
assert 'outlines' not in sys.modules
|
||||
|
||||
# The second LLM needs to request a higher gpu_memory_utilization because
|
||||
# the first LLM has already allocated a full 30% of the gpu memory.
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
enforce_eager=True,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
gpu_memory_utilization=0.3)
|
||||
gpu_memory_utilization=0.6)
|
||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
||||
outputs = llm.generate(
|
||||
prompts=[
|
||||
|
||||
@ -44,7 +44,7 @@ def test_offline_mode(llm: LLM, monkeypatch):
|
||||
LLM(model=MODEL_NAME,
|
||||
max_num_batched_tokens=4096,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.10,
|
||||
gpu_memory_utilization=0.20,
|
||||
enforce_eager=True)
|
||||
finally:
|
||||
# Reset the environment after the test
|
||||
|
||||
69
tests/worker/test_profile.py
Normal file
69
tests/worker/test_profile.py
Normal file
@ -0,0 +1,69 @@
|
||||
import torch
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
def test_gpu_memory_profiling():
|
||||
# Tests the gpu profiling that happens in order to determine the number of
|
||||
# KV cache blocks that we can allocate on the GPU.
|
||||
# This test mocks the maximum available gpu memory so that it can run on
|
||||
# any gpu setup.
|
||||
|
||||
# Set up engine args to build a worker.
|
||||
engine_args = EngineArgs(model="facebook/opt-125m",
|
||||
dtype="half",
|
||||
load_format="dummy")
|
||||
engine_config = engine_args.create_engine_config()
|
||||
engine_config.cache_config.num_gpu_blocks = 1000
|
||||
engine_config.cache_config.num_cpu_blocks = 1000
|
||||
|
||||
# Create the worker.
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
worker = Worker(
|
||||
model_config=engine_config.model_config,
|
||||
parallel_config=engine_config.parallel_config,
|
||||
scheduler_config=engine_config.scheduler_config,
|
||||
device_config=engine_config.device_config,
|
||||
cache_config=engine_config.cache_config,
|
||||
load_config=engine_config.load_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=True,
|
||||
)
|
||||
|
||||
# Load the model so we can profile it
|
||||
worker.init_device()
|
||||
worker.load_model()
|
||||
|
||||
# Set 10GiB as the total gpu ram to be device-agnostic
|
||||
def mock_mem_info():
|
||||
current_usage = torch.cuda.memory_stats(
|
||||
)["allocated_bytes.all.current"]
|
||||
mock_total_bytes = 10 * 1024**3
|
||||
free = mock_total_bytes - current_usage
|
||||
|
||||
return (free, mock_total_bytes)
|
||||
|
||||
from unittest.mock import patch
|
||||
with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
|
||||
gpu_blocks, _ = worker.determine_num_available_blocks()
|
||||
|
||||
# Peak vram usage by torch should be 0.7077 GiB
|
||||
# Non-torch allocations should be 0.0079 GiB
|
||||
# 9.0 GiB should be the utilization target
|
||||
# 8.2843 GiB should be available for the KV cache
|
||||
block_size = CacheEngine.get_cache_block_size(
|
||||
engine_config.cache_config, engine_config.model_config,
|
||||
engine_config.parallel_config)
|
||||
|
||||
expected_blocks = (8.2843 * 1024**3) // block_size
|
||||
|
||||
# Check within a small tolerance for portability
|
||||
# Hardware, kernel, or dependency changes could all affect memory
|
||||
# utilization
|
||||
assert abs(gpu_blocks - expected_blocks) < 5
|
||||
@ -217,42 +217,76 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
self.model_runner.profile_run()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
self._assert_memory_footprint_increased_during_profiling()
|
||||
|
||||
# Get the peak memory allocation recorded by torch
|
||||
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
|
||||
|
||||
# Check for any memory left around that may have been allocated on the
|
||||
# gpu outside of `torch`. NCCL operations, for example, can use a few
|
||||
# GB during a forward pass
|
||||
torch.cuda.empty_cache()
|
||||
# After emptying the torch cache, any other increase in gpu ram should
|
||||
# be from non-torch allocations.
|
||||
non_torch_allocations = free_memory_pre_profile - \
|
||||
torch.cuda.mem_get_info()[0]
|
||||
if non_torch_allocations > 0:
|
||||
peak_memory += non_torch_allocations
|
||||
|
||||
available_kv_cache_memory = (
|
||||
total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory)
|
||||
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
torch.cuda.synchronize()
|
||||
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
peak_memory = self.init_gpu_memory - free_gpu_memory
|
||||
assert peak_memory > 0, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_gpu_memory}, current free memory"
|
||||
f" {free_gpu_memory}. This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
if cache_block_size == 0:
|
||||
num_gpu_blocks = 0
|
||||
num_cpu_blocks = 0
|
||||
else:
|
||||
num_gpu_blocks = int(
|
||||
(total_gpu_memory * self.cache_config.gpu_memory_utilization -
|
||||
peak_memory) // cache_block_size)
|
||||
num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
cache_block_size)
|
||||
num_gpu_blocks = max(num_gpu_blocks, 0)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
|
||||
|
||||
logger.info(
|
||||
"Memory profiling results: total_gpu_memory=%.2fGiB"
|
||||
" initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
|
||||
" non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
|
||||
" gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
|
||||
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
|
||||
(peak_memory - non_torch_allocations) / (1024**3),
|
||||
non_torch_allocations / (1024**3),
|
||||
available_kv_cache_memory / (1024**3),
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
|
||||
# Final cleanup
|
||||
if self.model_runner.lora_manager:
|
||||
self.model_runner.remove_all_loras()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def _assert_memory_footprint_increased_during_profiling(self):
|
||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||
# GPU did not change their memory usage during the profiling.
|
||||
free_gpu_memory, _ = torch.cuda.mem_get_info()
|
||||
assert self.init_gpu_memory - free_gpu_memory > 0, (
|
||||
"Error in memory profiling. "
|
||||
f"Initial free memory {self.init_gpu_memory}, current free memory"
|
||||
f" {free_gpu_memory}. This happens when the GPU memory was "
|
||||
"not properly cleaned up before initializing the vLLM instance.")
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Allocate GPU and CPU KV cache with the specified number of blocks.
|
||||
|
||||
Loading…
Reference in New Issue
Block a user