From 7d9ffa2ae102cbfae65035c511f8d3c8e5fab986 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sat, 24 Aug 2024 00:51:38 -0700
Subject: [PATCH] [misc][core] lazy import outlines (#7831)

---
 .buildkite/test-pipeline.yaml               |  3 +-
 tests/entrypoints/llm/test_lazy_outlines.py | 48 +++++++++++++++++++
 .../guided_decoding/__init__.py             |  9 ++--
 .../lm_format_enforcer_decoding.py          | 11 +++--
 4 files changed, 64 insertions(+), 7 deletions(-)
 create mode 100644 tests/entrypoints/llm/test_lazy_outlines.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 283776c0..e4069386 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -87,7 +87,8 @@ steps:
   commands:
   - pip install -e ./plugins/vllm_add_dummy_model
   - pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@a4987bba6e9e9b3f22bd3a6c1ecf0abd04fd5622#egg=lm_eval[api]
-  - pytest -v -s entrypoints/llm
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py
+  - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/openai
 
 - label: Distributed Tests (4 GPUs) # 10min

diff --git a/tests/entrypoints/llm/test_lazy_outlines.py b/tests/entrypoints/llm/test_lazy_outlines.py
new file mode 100644
index 00000000..39480531
--- /dev/null
+++ b/tests/entrypoints/llm/test_lazy_outlines.py
@@ -0,0 +1,48 @@
+import sys
+
+from vllm import LLM, SamplingParams
+
+
+def test_lazy_outlines(sample_regex):
+    """If users don't use guided decoding, outlines should not be imported.
+    """
+    prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+
+    llm = LLM(model="facebook/opt-125m",
+              enforce_eager=True,
+              gpu_memory_utilization=0.3)
+    outputs = llm.generate(prompts, sampling_params)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # make sure outlines is not imported
+    assert 'outlines' not in sys.modules
+
+    llm = LLM(model="facebook/opt-125m",
+              enforce_eager=True,
+              guided_decoding_backend="lm-format-enforcer",
+              gpu_memory_utilization=0.3)
+    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+    outputs = llm.generate(
+        prompts=[
+            f"Give an example IPv4 address with this regex: {sample_regex}"
+        ] * 2,
+        sampling_params=sampling_params,
+        use_tqdm=True,
+        guided_options_request=dict(guided_regex=sample_regex))
+
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # make sure outlines is not imported
+    assert 'outlines' not in sys.modules

diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 4a2476dd..f9fcdead 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import (
     CompletionRequest)
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest)
-from vllm.model_executor.guided_decoding.outlines_decoding import (
-    get_local_outlines_guided_decoding_logits_processor,
-    get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor
 
 
@@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
     request = _adapt_request_for_tool_use(request)
 
     if guided_decoding_backend == 'outlines':
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
+            get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
     if guided_decoding_backend == 'lm-format-enforcer':
@@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
     # request = _adapt_request_for_tool_use(request)
 
     if guided_decoding_backend == 'outlines':
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
+            get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_options, tokenizer)
     if guided_decoding_backend == 'lm-format-enforcer':

diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
index 8de811a6..51f94798 100644
--- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
+++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest)
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest)
-from vllm.model_executor.guided_decoding.outlines_decoding import (
-    get_local_outlines_guided_decoding_logits_processor,
-    get_outlines_guided_decoding_logits_processor)
 from vllm.sampling_params import LogitsProcessor
 
 
@@ -43,6 +40,10 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
         character_level_parser = RegexParser(request.guided_regex)
     elif request.guided_grammar:
         # CFG grammar not supported by LMFE, revert to outlines
+
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import (
+            get_outlines_guided_decoding_logits_processor)
         return await get_outlines_guided_decoding_logits_processor(
             request, tokenizer)
     elif (request.response_format is not None
@@ -87,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
         character_level_parser = RegexParser(guided_options.guided_regex)
     elif guided_options.guided_grammar:
         # CFG grammar not supported by LMFE, revert to outlines
+
+        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
+        from vllm.model_executor.guided_decoding.outlines_decoding import (
+            get_local_outlines_guided_decoding_logits_processor)
         return get_local_outlines_guided_decoding_logits_processor(
             guided_options, tokenizer)
     elif guided_options.guided_json_object:
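
Note (editorial, not part of the patch): the lazy-import pattern applied above is worth seeing in isolation. The sketch below is a minimal stand-alone illustration, with the standard-library sqlite3 module standing in for a heavy optional dependency such as outlines, and an invented function name; it is not vLLM code. Moving the import into the branch that needs it means a process that never takes that branch never loads the module, which is exactly what tests/entrypoints/llm/test_lazy_outlines.py asserts for outlines via sys.modules.

    import sys


    def backend_version(backend: str) -> str:
        """Return a version string for the backend, importing it lazily."""
        if backend == "sqlite":
            # Deferred import: the module is loaded only when this branch
            # actually runs, not when this file is first imported.
            import sqlite3
            return sqlite3.sqlite_version
        raise ValueError(f"unsupported backend: {backend}")


    # Before the branch runs, the module has not been imported ...
    assert "sqlite3" not in sys.modules
    print(backend_version("sqlite"))
    # ... and afterwards it has. Once a module lands in sys.modules it
    # stays there for the life of the process, which is why the pipeline
    # change above runs test_lazy_outlines.py in its own pytest
    # invocation ("it needs a clean process").
    assert "sqlite3" in sys.modules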