[BugFix] fix num_lookahead_slots missing in async executor (#4165)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Authored by leiwen83 on 2024-05-01 01:12:59 +08:00, committed by GitHub
parent 26f2fb5113
commit 4bb53e2dde
9 changed files with 163 additions and 19 deletions
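
In short: the scheduler already computes num_lookahead_slots for speculative decoding and the synchronous executors receive it, but the async execute_model_async path dropped it. The stand-alone sketch below (not vLLM code; the names are illustrative stand-ins) shows the shape of the fix: an async wrapper around a blocking worker call has to accept and forward every argument the worker expects.

    import asyncio
    from functools import partial


    def execute_model(seq_group_metadata_list, blocks_to_swap_in,
                      blocks_to_swap_out, blocks_to_copy, num_lookahead_slots):
        # Stand-in for the driver worker's blocking execute_model(): the
        # speculative-decoding path needs num_lookahead_slots for its proposals.
        return f"step executed with {num_lookahead_slots} lookahead slot(s)"


    async def execute_model_async(*args, **kwargs):
        # Stand-in for make_async(...): run the blocking call in a worker
        # thread so the event loop stays responsive, forwarding all arguments.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, partial(execute_model, *args, **kwargs))


    async def main():
        # Before this patch the async executors neither accepted nor forwarded
        # num_lookahead_slots, so the async engine could not make this call;
        # the hunks below thread the value through every async executor.
        print(await execute_model_async([], {}, {}, {}, num_lookahead_slots=2))


    asyncio.run(main())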


@@ -1,10 +1,127 @@
-from typing import List, Tuple
+import asyncio
+from typing import List, Optional, Tuple, Union
 
 import pytest
+import ray
 
 from tests.conftest import cleanup
 from vllm import LLM
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.lora.request import LoRARequest
 from vllm.model_executor.utils import set_random_seed
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import MultiModalData
+from vllm.usage.usage_lib import UsageContext
+from vllm.utils import Counter, random_uuid
+
+
+class AsyncLLM:
+    """AsyncLLM
+
+    Note: the current LLM class in vllm does not support async mode. For test
+    purposes we implement an async variant here; it may move to
+    vllm/entrypoints/llm.py in the future.
+
+    AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py, with
+    changes to make it work in async mode.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        tokenizer: Optional[str] = None,
+        tokenizer_mode: str = "auto",
+        skip_tokenizer_init: bool = False,
+        trust_remote_code: bool = False,
+        tensor_parallel_size: int = 1,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
+        revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
+        seed: int = 0,
+        gpu_memory_utilization: float = 0.9,
+        swap_space: int = 4,
+        enforce_eager: bool = False,
+        max_context_len_to_capture: int = 8192,
+        disable_custom_all_reduce: bool = False,
+        **kwargs,
+    ) -> None:
+        if "disable_log_stats" not in kwargs:
+            kwargs["disable_log_stats"] = True
+        self.engine_args = AsyncEngineArgs(
+            model=model,
+            tokenizer=tokenizer,
+            tokenizer_mode=tokenizer_mode,
+            skip_tokenizer_init=skip_tokenizer_init,
+            trust_remote_code=trust_remote_code,
+            tensor_parallel_size=tensor_parallel_size,
+            dtype=dtype,
+            quantization=quantization,
+            revision=revision,
+            tokenizer_revision=tokenizer_revision,
+            seed=seed,
+            gpu_memory_utilization=gpu_memory_utilization,
+            swap_space=swap_space,
+            enforce_eager=enforce_eager,
+            max_context_len_to_capture=max_context_len_to_capture,
+            engine_use_ray=True,
+            disable_custom_all_reduce=disable_custom_all_reduce,
+            **kwargs,
+        )
+        self.request_counter = Counter()
+
+    def generate(
+        self,
+        prompts: Optional[Union[str, List[str]]] = None,
+        sampling_params: Optional[Union[SamplingParams,
+                                        List[SamplingParams]]] = None,
+        prompt_token_ids: Optional[List[List[int]]] = None,
+        use_tqdm: bool = True,
+        lora_request: Optional[LoRARequest] = None,
+        multi_modal_data: Optional[MultiModalData] = None,
+    ) -> List[RequestOutput]:
+
+        llm_engine = AsyncLLMEngine.from_engine_args(
+            self.engine_args, usage_context=UsageContext.LLM_CLASS)
+
+        if prompts is None:
+            raise ValueError("prompts must be provided.")
+        if isinstance(prompts, str):
+            # Convert a single prompt to a list.
+            prompts = [prompts]
+
+        if prompts is not None:
+            num_requests = len(prompts)
+
+        if sampling_params is None:
+            # Use default sampling params.
+            sampling_params = SamplingParams()
+
+        elif isinstance(sampling_params,
+                        list) and len(sampling_params) != num_requests:
+            raise ValueError("The lengths of prompts and "
+                             "sampling_params must be the same.")
+
+        async def get_output(prompt, sampling_param) -> str:
+            request_id = random_uuid()
+            results_generator = llm_engine.generate(prompt, sampling_param,
+                                                    request_id)
+            final_output = None
+            async for request_output in results_generator:
+                final_output = request_output
+            return final_output
+
+        outputs = []
+        try:
+            for i in range(num_requests):
+                prompt = prompts[i] if prompts is not None else None
+                res = asyncio.run(get_output(prompt, sampling_params))
+                outputs.append(res)
+        finally:
+            ray.shutdown()
+        return outputs
+
 
 @pytest.fixture
@@ -36,8 +153,12 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
 
     def generator_inner():
         print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
-        llm = LLM(**kwargs)
+
+        use_async = False
+        if "use_async" in kwargs:
+            use_async = kwargs.pop("use_async")
+        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
         set_random_seed(seed)
 
         yield llm
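
The AsyncLLM helper above is constructed indirectly through create_llm_generator (next hunk). Purely for illustration, a hypothetical direct use would look like this; the model name and kwargs are taken from the test parameters further down, and are assumptions otherwise.

    from vllm import SamplingParams

    # Hypothetical direct use of the AsyncLLM test helper; the real tests build
    # it via create_llm_generator() with kwargs from @pytest.mark.parametrize.
    llm = AsyncLLM(model="JackFram/llama-68m", enforce_eager=True)
    outputs = llm.generate(prompts=["Hello, my name is"],
                           sampling_params=SamplingParams(temperature=0.0,
                                                          max_tokens=32))
    for output in outputs:
        print(output.outputs[0].text)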


@@ -42,10 +42,17 @@ def test_spec_decode_xfail_ray(test_llm_generator):
         temperature=temperature,
     )
 
-    with pytest.raises(AssertionError,
-                       match="Speculative decoding not yet supported for "):
-        get_output_from_llm_generator(test_llm_generator, prompts,
-                                      sampling_params)
+    try:
+        with pytest.raises(
+                AssertionError,
+                match="Speculative decoding not yet supported for "):
+            get_output_from_llm_generator(test_llm_generator, prompts,
+                                          sampling_params)
+    finally:
+        # We need to free up the Ray resources so that later tests can use
+        # the GPU we allocated here.
+        import ray
+        ray.shutdown()
 
 
 @pytest.mark.parametrize(
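
The same cleanup pattern, reduced to a self-contained illustration with a stub in place of get_output_from_llm_generator (not part of this commit): even when the expected assertion fires, the finally block releases Ray's resources so later tests can claim the GPU.

    import pytest
    import ray


    def _start_engine_and_fail():
        # Stub standing in for get_output_from_llm_generator(): pretend the
        # engine rejects the unsupported configuration after Ray came up.
        ray.init(ignore_reinit_error=True)
        raise AssertionError("Speculative decoding not yet supported for this")


    def test_xfail_still_releases_ray_resources():
        try:
            with pytest.raises(
                    AssertionError,
                    match="Speculative decoding not yet supported for "):
                _start_engine_and_fail()
        finally:
            # Free the Ray resources regardless of how the assertion went.
            ray.shutdown()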


@@ -40,7 +40,8 @@ from .conftest import get_output_from_llm_generator
 
 @pytest.mark.parametrize(
     "common_llm_kwargs",
-    [{
+    [
+        {
         # Use a small model for a fast test.
         # Note this is repeated in the test body; to initialize a tokenizer.
         "model": "JackFram/llama-68m",
@@ -49,8 +50,14 @@ from .conftest import get_output_from_llm_generator
         "enforce_eager": True,
 
         # Required for spec decode.
-        "use_v2_block_manager": True
-    }])
+        "use_v2_block_manager": True,
+
+        # Whether to use the AsyncLLM engine.
+        "use_async": async_mode,
+    }
+    # Try both async and sync engine execution.
+    for async_mode in [True, False]
+    ])
 @pytest.mark.parametrize(
     "per_test_common_llm_kwargs",
     [
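
For readability, here is what the common_llm_kwargs parametrization above evaluates to, showing only the keys visible in this hunk (the real dict carries a few more entries); the variable name is illustrative.

    # Illustration only: the comprehension yields one kwargs dict per engine
    # mode, so every test in the module runs with both sync and async engines.
    common_llm_kwargs_cases = [{
        "model": "JackFram/llama-68m",
        "enforce_eager": True,
        "use_v2_block_manager": True,
        "use_async": async_mode,
    } for async_mode in [True, False]]

    assert [case["use_async"] for case in common_llm_kwargs_cases] == [True, False]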


@@ -211,9 +211,11 @@ class _AsyncLLMEngine(LLMEngine):
 
         if not scheduler_outputs.is_empty():
             # Execute the model.
             output = await self.model_executor.execute_model_async(
-                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                seq_group_metadata_list,
+                scheduler_outputs.blocks_to_swap_in,
                 scheduler_outputs.blocks_to_swap_out,
-                scheduler_outputs.blocks_to_copy)
+                scheduler_outputs.blocks_to_copy,
+                num_lookahead_slots=scheduler_outputs.num_lookahead_slots)
         else:
             output = []


@@ -109,12 +109,14 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
     ) -> List[SamplerOutput]:
         output = await make_async(self.driver_worker.execute_model)(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy)
+            blocks_to_copy=blocks_to_copy,
+            num_lookahead_slots=num_lookahead_slots)
         return output
 
     async def check_health_async(self) -> None:


@@ -112,6 +112,7 @@ class ExecutorAsyncBase(ExecutorBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
     ) -> List[SamplerOutput]:
         """Executes one model step on the given sequences."""
         raise NotImplementedError
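
Pieced together from this hunk (the first parameter and the rest of the class sit outside the diff context), the interface every async executor now implements reads roughly as follows; the type aliases are stand-ins for the real vLLM types, not the actual definitions.

    from typing import Any, Dict, List

    # Stand-ins for vLLM types referenced but not defined in this hunk.
    SequenceGroupMetadata = Any
    SamplerOutput = Any


    class ExecutorAsyncBase:  # the real class derives from ExecutorBase

        async def execute_model_async(
            self,
            seq_group_metadata_list: List[SequenceGroupMetadata],
            blocks_to_swap_in: Dict[int, int],
            blocks_to_swap_out: Dict[int, int],
            blocks_to_copy: Dict[int, List[int]],
            num_lookahead_slots: int,
        ) -> List[SamplerOutput]:
            """Executes one model step on the given sequences."""
            raise NotImplementedError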


@@ -163,10 +163,12 @@ class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
     ) -> List[SamplerOutput]:
         output = await make_async(self.driver_worker.execute_model)(
             seq_group_metadata_list=seq_group_metadata_list,
             blocks_to_swap_in=blocks_to_swap_in,
             blocks_to_swap_out=blocks_to_swap_out,
-            blocks_to_copy=blocks_to_copy)
+            blocks_to_copy=blocks_to_copy,
+            num_lookahead_slots=num_lookahead_slots)
         return output


@@ -84,6 +84,7 @@ class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
+        num_lookahead_slots: int,
     ) -> List[SamplerOutput]:
         output = await make_async(self.driver_worker.execute_model)(
             seq_group_metadata_list=seq_group_metadata_list, )


@@ -196,6 +196,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
                 "blocks_to_swap_in": blocks_to_swap_in,
                 "blocks_to_swap_out": blocks_to_swap_out,
                 "blocks_to_copy": blocks_to_copy,
+                "num_lookahead_slots": num_lookahead_slots,
             },
             use_ray_compiled_dag=USE_RAY_COMPILED_DAG)