[Bug fix][Core] fixup ngram not setup correctly (#4551)
Co-authored-by: Lei Wen <wenlei03@qiyi.com>
Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
parent 469f85c782
commit 8344f7742b
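
For context, the code path this commit repairs is the setup of ngram (prompt-lookup) speculative decoding. The sketch below shows how a user would typically enable it through the LLM entrypoint; the model name and the speculative settings are illustrative placeholders, not values taken from this commit.

from vllm import LLM, SamplingParams

# Illustrative configuration only: "[ngram]" selects prompt-lookup (ngram)
# speculation instead of a separate draft model; the other values are placeholders.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=4,
    use_v2_block_manager=True,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)

Before this fix, the ngram settings were not forwarded to the speculative decoding worker, so the configuration above silently fell back to the draft-model proposer.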
@@ -55,7 +55,7 @@ class AsyncLLM:
     ) -> None:
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
-        self.engine_args = AsyncEngineArgs(
+        engine_args = AsyncEngineArgs(
             model=model,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
@@ -76,6 +76,8 @@ class AsyncLLM:
             **kwargs,
         )
         self.request_counter = Counter()
+        self.llm_engine = AsyncLLMEngine.from_engine_args(
+            engine_args, usage_context=UsageContext.LLM_CLASS)
 
     def generate(
         self,
@@ -88,9 +90,6 @@ class AsyncLLM:
         multi_modal_data: Optional[MultiModalData] = None,
     ) -> List[RequestOutput]:
 
-        llm_engine = AsyncLLMEngine.from_engine_args(
-            self.engine_args, usage_context=UsageContext.LLM_CLASS)
-
         if prompts is None:
             raise ValueError("prompts must be provided.")
         if isinstance(prompts, str):
@@ -111,8 +110,8 @@ class AsyncLLM:
 
         async def get_output(prompt, sampling_param) -> str:
             request_id = random_uuid()
-            results_generator = llm_engine.generate(prompt, sampling_param,
-                                                    request_id)
+            results_generator = self.llm_engine.generate(
+                prompt, sampling_param, request_id)
             final_output = None
             async for request_output in results_generator:
                 final_output = request_output
@@ -185,12 +184,25 @@ def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
     return generator_outer
 
 
+def maybe_assert_ngram_worker(llm):
+    # Verify the proposer worker is ngram if ngram is specified.
+    if (not isinstance(llm, AsyncLLM)
+            and llm.llm_engine.speculative_config is not None
+            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
+        from vllm.spec_decode.ngram_worker import NGramWorker
+        assert isinstance(
+            llm.llm_engine.model_executor.driver_worker.proposer_worker,
+            NGramWorker)
+
+
 def get_output_from_llm_generator(
         llm_generator, prompts,
         sampling_params) -> Tuple[List[str], List[List[int]]]:
     tokens = []
     token_ids = []
     for llm in llm_generator():
+        maybe_assert_ngram_worker(llm)
+
         outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
         token_ids = [output.outputs[0].token_ids for output in outputs]
         tokens = [output.outputs[0].text for output in outputs]
@@ -82,6 +82,10 @@ class GPUExecutor(ExecutorBase):
         draft_worker_kwargs.update(
             model_config=self.speculative_config.draft_model_config,
             parallel_config=self.speculative_config.draft_parallel_config,
+            ngram_prompt_lookup_max=self.speculative_config.
+            ngram_prompt_lookup_max,
+            ngram_prompt_lookup_min=self.speculative_config.
+            ngram_prompt_lookup_min,
             # TODO allow draft-model specific load config.
             #load_config=self.load_config,
         )
@@ -57,13 +57,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         draft_worker_kwargs,
     ) -> "SpecDecodeWorker":
 
-        if "ngram_prompt_lookup_max" in draft_worker_kwargs:
-            ngram_prompt_lookup_max = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
-            ngram_prompt_lookup_min = (
-                draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
-        else:
-            ngram_prompt_lookup_max = 0
+        ngram_prompt_lookup_max = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_max"))
+        ngram_prompt_lookup_min = (
+            draft_worker_kwargs.pop("ngram_prompt_lookup_min"))
 
         if ngram_prompt_lookup_max > 0:
             proposer_worker = NGramWorker(**draft_worker_kwargs)
@@ -72,6 +69,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         else:
             proposer_worker = MultiStepWorker(**draft_worker_kwargs)
 
+        logger.info("Configuring SpecDecodeWorker with proposer=%s",
+                    type(proposer_worker))
+
         return SpecDecodeWorker(
             proposer_worker,
             scorer_worker,
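
Taken together, the executor now always forwards the ngram settings, so the worker factory can pop them unconditionally and choose the proposer type from their value. A short sketch of the resulting selection logic, paraphrased from the hunks above rather than copied from the vLLM sources:

# Paraphrase of SpecDecodeWorker.create_worker after this fix; the ngram keys are
# guaranteed to be present because GPUExecutor now passes them unconditionally.
ngram_prompt_lookup_max = draft_worker_kwargs.pop("ngram_prompt_lookup_max")
ngram_prompt_lookup_min = draft_worker_kwargs.pop("ngram_prompt_lookup_min")

if ngram_prompt_lookup_max > 0:
    # Prompt-lookup speculation: the proposer scans the prompt for matching ngrams.
    proposer_worker = NGramWorker(**draft_worker_kwargs)
    # The surrounding vLLM code also configures the lookup window on the worker
    # using ngram_prompt_lookup_min/max (not shown in the hunks above).
else:
    # Fall back to the draft-model proposer.
    proposer_worker = MultiStepWorker(**draft_worker_kwargs)

logger.info("Configuring SpecDecodeWorker with proposer=%s", type(proposer_worker))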