[Bugfix][Core] Use torch.cuda.memory_stats() to profile peak memory usage (#9352)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
parent 48138a8415
commit de4008e2ab
@@ -26,10 +26,12 @@ def test_lazy_outlines(sample_regex):
     # make sure outlines is not imported
     assert 'outlines' not in sys.modules
 
+    # The second LLM needs to request a higher gpu_memory_utilization because
+    # the first LLM has already allocated a full 30% of the gpu memory.
     llm = LLM(model="facebook/opt-125m",
               enforce_eager=True,
               guided_decoding_backend="lm-format-enforcer",
-              gpu_memory_utilization=0.3)
+              gpu_memory_utilization=0.6)
     sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
     outputs = llm.generate(
         prompts=[
@@ -44,7 +44,7 @@ def test_offline_mode(llm: LLM, monkeypatch):
         LLM(model=MODEL_NAME,
             max_num_batched_tokens=4096,
             tensor_parallel_size=1,
-            gpu_memory_utilization=0.10,
+            gpu_memory_utilization=0.20,
             enforce_eager=True)
     finally:
         # Reset the environment after the test
tests/worker/test_profile.py (new file, 69 lines)
@@ -0,0 +1,69 @@
import torch

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker


def test_gpu_memory_profiling():
    # Tests the gpu profiling that happens in order to determine the number of
    # KV cache blocks that we can allocate on the GPU.
    # This test mocks the maximum available gpu memory so that it can run on
    # any gpu setup.

    # Set up engine args to build a worker.
    engine_args = EngineArgs(model="facebook/opt-125m",
                             dtype="half",
                             load_format="dummy")
    engine_config = engine_args.create_engine_config()
    engine_config.cache_config.num_gpu_blocks = 1000
    engine_config.cache_config.num_cpu_blocks = 1000

    # Create the worker.
    distributed_init_method = get_distributed_init_method(
        get_ip(), get_open_port())
    worker = Worker(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        local_rank=0,
        rank=0,
        distributed_init_method=distributed_init_method,
        is_driver_worker=True,
    )

    # Load the model so we can profile it
    worker.init_device()
    worker.load_model()

    # Set 10GiB as the total gpu ram to be device-agnostic
    def mock_mem_info():
        current_usage = torch.cuda.memory_stats(
        )["allocated_bytes.all.current"]
        mock_total_bytes = 10 * 1024**3
        free = mock_total_bytes - current_usage

        return (free, mock_total_bytes)

    from unittest.mock import patch
    with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
        gpu_blocks, _ = worker.determine_num_available_blocks()

    # Peak vram usage by torch should be 0.7077 GiB
    # Non-torch allocations should be 0.0079 GiB
    # 9.0 GiB should be the utilization target
    # 8.2843 GiB should be available for the KV cache
    block_size = CacheEngine.get_cache_block_size(
        engine_config.cache_config, engine_config.model_config,
        engine_config.parallel_config)

    expected_blocks = (8.2843 * 1024**3) // block_size

    # Check within a small tolerance for portability
    # Hardware, kernel, or dependency changes could all affect memory
    # utilization
    assert abs(gpu_blocks - expected_blocks) < 5
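For readers checking the numbers: the expected_blocks value follows from the mocked 10 GiB device, the figures quoted in the test's comments, and vLLM's default gpu_memory_utilization of 0.9. A minimal sketch of that arithmetic (the names below are illustrative and not part of the diff):

GiB = 1024**3
utilization_target = 0.9 * (10 * GiB)  # 9.0 GiB target on the mocked 10 GiB device
peak_torch = 0.7077 * GiB              # peak torch allocations quoted in the test comments
non_torch = 0.0079 * GiB               # allocations made outside torch, e.g. NCCL buffers
kv_cache_bytes = utilization_target - peak_torch - non_torch
print(kv_cache_bytes / GiB)            # ~8.284 GiB, matching the 8.2843 figure above up to rounding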
@@ -217,42 +217,76 @@ class Worker(LocalOrDistributedWorkerBase):
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
 
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
         self.model_runner.profile_run()
+        torch.cuda.synchronize()
+
+        self._assert_memory_footprint_increased_during_profiling()
+
+        # Get the peak memory allocation recorded by torch
+        peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
+
+        # Check for any memory left around that may have been allocated on the
+        # gpu outside of `torch`. NCCL operations, for example, can use a few
+        # GB during a forward pass
+        torch.cuda.empty_cache()
+        # After emptying the torch cache, any other increase in gpu ram should
+        # be from non-torch allocations.
+        non_torch_allocations = free_memory_pre_profile - \
+            torch.cuda.mem_get_info()[0]
+        if non_torch_allocations > 0:
+            peak_memory += non_torch_allocations
+
+        available_kv_cache_memory = (
+            total_gpu_memory * self.cache_config.gpu_memory_utilization -
+            peak_memory)
 
         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
-        torch.cuda.synchronize()
-        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        # NOTE(woosuk): Here we assume that the other processes using the same
-        # GPU did not change their memory usage during the profiling.
-        peak_memory = self.init_gpu_memory - free_gpu_memory
-        assert peak_memory > 0, (
-            "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory}, current free memory"
-            f" {free_gpu_memory}. This happens when the GPU memory was "
-            "not properly cleaned up before initializing the vLLM instance.")
-
         cache_block_size = self.get_cache_block_size_bytes()
         if cache_block_size == 0:
             num_gpu_blocks = 0
             num_cpu_blocks = 0
         else:
-            num_gpu_blocks = int(
-                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-                 peak_memory) // cache_block_size)
+            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
             num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                  cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
+
+        logger.info(
+            "Memory profiling results: total_gpu_memory=%.2fGiB"
+            " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
+            " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
+            " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
+            (total_gpu_memory - free_memory_pre_profile) / (1024**3),
+            (peak_memory - non_torch_allocations) / (1024**3),
+            non_torch_allocations / (1024**3),
+            available_kv_cache_memory / (1024**3),
+            self.cache_config.gpu_memory_utilization)
+
+        # Final cleanup
         if self.model_runner.lora_manager:
             self.model_runner.remove_all_loras()
         gc.collect()
-        torch.cuda.empty_cache()
+
         return num_gpu_blocks, num_cpu_blocks
 
+    def _assert_memory_footprint_increased_during_profiling(self):
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        assert self.init_gpu_memory - free_gpu_memory > 0, (
+            "Error in memory profiling. "
+            f"Initial free memory {self.init_gpu_memory}, current free memory"
+            f" {free_gpu_memory}. This happens when the GPU memory was "
+            "not properly cleaned up before initializing the vLLM instance.")
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Allocate GPU and CPU KV cache with the specified number of blocks.
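The net effect of this hunk: peak usage is now read from torch's own allocator statistics instead of the drop in device-wide free memory, and whatever is still missing from the free pool after emptying torch's cache is attributed to non-torch allocations such as NCCL workspaces. A minimal standalone sketch of the same pattern, where run_profiling_pass is a placeholder for the dummy forward pass and not a vLLM API:

import torch

def available_kv_cache_bytes(run_profiling_pass, gpu_memory_utilization=0.9):
    # Start from a clean allocator state and snapshot free/total device memory.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    free_before, total = torch.cuda.mem_get_info()

    run_profiling_pass()  # placeholder for the dummy forward pass
    torch.cuda.synchronize()

    # Peak bytes allocated through torch during the pass.
    peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

    # After dropping torch's cached blocks, any remaining drop in free memory
    # must come from allocations made outside torch (NCCL workspaces, etc.).
    torch.cuda.empty_cache()
    non_torch = free_before - torch.cuda.mem_get_info()[0]
    if non_torch > 0:
        peak += non_torch

    # Memory left for the KV cache under the configured utilization target.
    return total * gpu_memory_utilization - peak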