diff --git a/tests/conftest.py b/tests/conftest.py
index af04cfbb..d904058d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                           LlavaConfig, LlavaForConditionalGeneration)
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.inputs import PromptInputs
+from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SampleLogprobs
 
 logger = init_logger(__name__)
 
@@ -188,10 +189,11 @@ class HfRunner:
         prompts: List[str],
         images: Optional[List[Image.Image]] = None,
         **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
             assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -201,17 +203,13 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
-            inputs = {
-                key: value.cuda() if value is not None else None
-                for key, value in inputs.items()
-            }
 
             output_ids = self.model.generate(
-                **inputs,
+                **inputs.to("cuda"),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
@@ -224,23 +222,22 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@@ -282,9 +279,7 @@ class HfRunner:
                 if self.model.get_output_embeddings().bias is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
             all_logprobs.append(seq_logprobs)
         return all_logprobs
@@ -294,10 +289,10 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
-        all_logprobs = []
-        all_output_ids = []
-        all_output_strs = []
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
 
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@ class HfRunner:
                 return_dict_in_generate=True,
             )
 
-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for _, hidden_states in enumerate(output.hidden_states):
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -321,13 +316,11 @@ class HfRunner:
                            None) is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
 
             # convert to dict
-            seq_logprobs_lst = []
+            seq_logprobs_lst: List[Dict[int, float]] = []
             for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                 # drop prompt logprobs
                 if tok_idx == 0:
@@ -372,13 +365,13 @@ class VllmRunner:
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
-        max_model_len=1024,
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
-        swap_space=4,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -399,32 +392,31 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[torch.Tensor] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
-            assert len(prompts) == images.shape[0]
+            assert len(prompts) == len(images)
 
-        prompt_inputs: List[PromptInputs] = []
+        prompt_inputs: List[TextPrompt] = []
         for i, prompt in enumerate(prompts):
-            image = None if images is None else images[i:i + 1]
-            mm_data = None if image is None else MultiModalData(
-                type=MultiModalData.Type.IMAGE,
-                data=image,
-            )
+            prompt = TextPrompt(prompt=prompt)
+            if images is not None:
+                prompt["multi_modal_data"] = MultiModalData(
+                    type=MultiModalData.Type.IMAGE,
+                    data=images[i:i + 1],
+                )
 
-            prompt_inputs.append({
-                "prompt": prompt,
-                "multi_modal_data": mm_data,
-            })
+            prompt_inputs.append(prompt)
 
         req_outputs = self.model.generate(prompt_inputs,
                                           sampling_params=sampling_params)
-        outputs = []
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = sample.token_ids
@@ -437,12 +429,12 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
 
         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
-        outputs = []
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
@@ -467,7 +459,7 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
@@ -481,7 +473,7 @@ class VllmRunner:
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,
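For reference, the input-building pattern that the refactored VllmRunner.generate adopts can be exercised directly against the LLM API. The sketch below is illustrative only, not part of the patch: the model name, prompts, and sampling settings are placeholders, and it assumes the vLLM revision this diff targets, where vllm.inputs.TextPrompt is a TypedDict and MultiModalData is still exported from vllm.sequence. With images left as None it degenerates to plain text prompts, which is the safe default since image inputs additionally require the vision-language engine configuration.

from typing import List, Optional

import torch

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt
from vllm.sequence import MultiModalData

# Placeholder model; the real tests construct this through VllmRunner.
llm = LLM(model="facebook/opt-125m", max_model_len=1024, dtype="half")

prompts = ["Hello, my name is", "The capital of France is"]
images: Optional[torch.Tensor] = None  # pixel values, one row per prompt

prompt_inputs: List[TextPrompt] = []
for i, text in enumerate(prompts):
    prompt = TextPrompt(prompt=text)
    if images is not None:
        # Attach one image per prompt, mirroring the images[i:i + 1] slice
        # used in the diff above.
        prompt["multi_modal_data"] = MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=images[i:i + 1],
        )
    prompt_inputs.append(prompt)

req_outputs = llm.generate(prompt_inputs,
                           sampling_params=SamplingParams(max_tokens=16))
for req_output in req_outputs:
    print(req_output.outputs[0].text)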