diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index cee5e7e9..2eeba904 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -76,7 +76,7 @@ steps:
   #mirror_hardwares: [amd]
   commands:
     - bash ../.buildkite/download-images.sh
-    - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
+    - pytest -v -s models --ignore=models/test_llava.py
 
 - label: Llava Test
   #mirror_hardwares: [amd]
diff --git a/tests/conftest.py b/tests/conftest.py
index 67132691..1f2ad1cb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -272,6 +272,68 @@ class HfRunner:
             all_logprobs.append(seq_logprobs)
         return all_logprobs
 
+    def generate_greedy_logprobs_limit(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str]]:
+        all_logprobs = []
+        all_output_ids = []
+        all_output_strs = []
+
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+
+            seq_logprobs = []
+            for _, hidden_states in enumerate(output.hidden_states):
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if getattr(self.model.get_output_embeddings(), "bias",
+                           None) is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+
+            # convert to dict
+            seq_logprobs_lst = []
+            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+                # drop prompt logprobs
+                if tok_idx == 0:
+                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+                topk = tok_logprobs.topk(num_logprobs)
+
+                tok_logprobs_dct = {}
+                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+                    tok_logprobs_dct[token_id.item()] = logprob.item()
+
+                seq_logprobs_lst.append(tok_logprobs_dct)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_len = seq_ids.shape[0] - input_ids.shape[1]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def __del__(self):
         del self.model
         cleanup()
diff --git a/tests/models/test_big_models.py b/tests/models/test_big_models.py
index 3dde498b..c02204f1 100644
--- a/tests/models/test_big_models.py
+++ b/tests/models/test_big_models.py
@@ -8,7 +8,7 @@ import pytest
 
 MODELS = [
     "meta-llama/Llama-2-7b-hf",
-    # "mistralai/Mistral-7B-v0.1", # Broken
+    # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py
    # "Deci/DeciLM-7b", # Broken
     # "tiiuae/falcon-7b", # Broken
     "EleutherAI/gpt-j-6b",
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 7aeff3a9..33d28da8 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -4,6 +4,8 @@ Run `pytest tests/models/test_mistral.py`.
 """
 import pytest
 
+from tests.models.utils import check_logprobs_close
+
 MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.1",
 ]
@@ -11,30 +13,31 @@ MODELS = [
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.skip(
-    "Two problems: 1. Failing correctness tests. 2. RuntimeError: expected "
-    "scalar type BFloat16 but found Half (only in CI).")
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
 def test_models(
     hf_runner,
     vllm_runner,
-    example_long_prompts,
+    example_prompts,
     model: str,
     dtype: str,
     max_tokens: int,
+    num_logprobs: int,
 ) -> None:
+    # TODO(sang): Sliding window should be tested separately.
     hf_model = hf_runner(model, dtype=dtype)
-    hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
+    hf_outputs = hf_model.generate_greedy_logprobs_limit(
+        example_prompts, max_tokens, num_logprobs)
     del hf_model
 
     vllm_model = vllm_runner(model, dtype=dtype)
-    vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
+    vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
+                                                       max_tokens,
+                                                       num_logprobs)
     del vllm_model
-
-    for i in range(len(example_long_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 857d70fa..f41e0f30 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -109,7 +109,7 @@ class RotaryEmbedding(nn.Module):
             key_pass = key[..., self.rotary_dim:]
 
         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device)
+            positions.device, dtype=query.dtype)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -143,7 +143,8 @@ class RotaryEmbedding(nn.Module):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+                                                   dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None: