From 6a0f617210dfba76f3db4db1155d1f1489609133 Mon Sep 17 00:00:00 2001 From: SangBin Cho Date: Fri, 10 May 2024 23:54:32 +0900 Subject: [PATCH] [Core] Fix circular reference which leaked llm instance in local dev env (#4737) Storing exception frame is extremely prone to circular refernece because it contains the reference to objects. When tensorizer is not installed, it leaks llm instance because error frame has references to various modules which cause circular reference problem. I also found spec decoding has a circular reference issue, and I solved it using weakref.proxy. --- tests/basic_correctness/test_basic_correctness.py | 13 +++++++++++++ vllm/model_executor/model_loader/tensorizer.py | 10 +++++----- vllm/spec_decode/multi_step_worker.py | 3 ++- vllm/spec_decode/ngram_worker.py | 3 ++- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index d75279dd..7d811744 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -3,9 +3,12 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import os +import weakref import pytest +from vllm import LLM + MODELS = [ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", @@ -13,6 +16,16 @@ MODELS = [ VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND" +def test_vllm_gc_ed(): + """Verify vllm instance is GC'ed when it is deleted""" + llm = LLM("facebook/opt-125m") + weak_llm = weakref.ref(llm) + del llm + # If there's any circular reference to vllm, this fails + # because llm instance is not GC'ed. + assert weak_llm() is None + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index af433b86..219a2a39 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -tensorizer_load_fail = None +tensorizer_error_msg = None try: from tensorizer import (DecryptionParams, EncryptionParams, @@ -28,7 +28,7 @@ try: from tensorizer.utils import (convert_bytes, get_mem_usage, no_init_or_tensor) except ImportError as e: - tensorizer_load_fail = e + tensorizer_error_msg = str(e) __all__ = [ 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', @@ -254,11 +254,11 @@ class TensorizerAgent: def __init__(self, tensorizer_config: TensorizerConfig, quant_config: QuantizationConfig, **extra_kwargs): - if tensorizer_load_fail is not None: + if tensorizer_error_msg is not None: raise ImportError( "Tensorizer is not installed. Please install tensorizer " - "to use this feature with `pip install vllm[tensorizer]`." - ) from tensorizer_load_fail + "to use this feature with `pip install vllm[tensorizer]`. " + "Error message: {}".format(tensorizer_error_msg)) self.tensorizer_config = tensorizer_config self.tensorizer_args = ( diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 5044cc1e..20098eba 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -1,4 +1,5 @@ import copy +import weakref from typing import List, Tuple import torch @@ -32,7 +33,7 @@ class MultiStepWorker(Worker): super().init_device() self._proposer = Top1Proposer( - self, + weakref.proxy(self), self.device, self.vocab_size, max_proposal_len=self.max_model_len, diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index f18f9387..6cd50fcc 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -1,3 +1,4 @@ +import weakref from typing import List, Optional, Tuple import torch @@ -37,7 +38,7 @@ class NGramWorker(LoraNotSupportedWorkerBase): # Current only support Top1Proposer self._proposer = Top1Proposer( - self, + weakref.proxy(self), device=self.device, vocab_size=self.vocab_size, )