From 996d095c541e1cd67f0a7ec2579bc3bb0a435494 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Sun, 3 Mar 2024 14:37:18 -0800
Subject: [PATCH] [FIX] Fix styles in automatic prefix caching & add an
 automatic prefix caching benchmark (#3158)

---
 benchmarks/benchmark_prefix_caching.py | 59 ++++++++++++++++++++++++++
 benchmarks/benchmark_throughput.py     |  5 ++-
 vllm/core/block_manager.py             | 15 ++-----
 vllm/sequence.py                       |  8 +---
 4 files changed, 69 insertions(+), 18 deletions(-)
 create mode 100644 benchmarks/benchmark_prefix_caching.py

diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
new file mode 100644
index 00000000..c43bd9c3
--- /dev/null
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -0,0 +1,59 @@
+import argparse
+import time
+
+from vllm import LLM
+from vllm import SamplingParams
+
+PROMPT = "You are a helpful assistant that recognizes the content of tables in markdown format. Here is a table as follows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat's the content in the (1,1) cell?\n"
+
+
+def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None):
+    start_time = time.time()
+    # Whether to use the shared prefix.
+    if prefix_len is not None:
+        # Start inference with the prefix position hint.
+        llm.generate(prompts,
+                     sampling_params=sampling_params,
+                     prefix_pos=prefix_len)
+    else:
+        llm.generate(prompts, sampling_params=sampling_params)
+
+    end_time = time.time()
+    print(f"cost time {end_time - start_time}")
+
+
+def main(args):
+    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
+              tokenizer_mode='auto',
+              trust_remote_code=True,
+              enforce_eager=True,
+              enable_prefix_caching=args.enable_prefix_caching)
+
+    num_prompts = 100
+    prompts = [PROMPT] * num_prompts
+    sampling_params = SamplingParams(temperature=0, max_tokens=100)
+
+    print("------warm up------")
+    test_prefix(
+        llm=llm,
+        prompts=prompts[:1],
+        sampling_params=sampling_params,
+    )
+
+    print("------start generating------")
+    test_prefix(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='Benchmark the performance with or without automatic '
+        'prefix caching.')
+    parser.add_argument('--enable-prefix-caching',
+                        action='store_true',
+                        help='enable prefix caching')
+    args = parser.parse_args()
+    main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 51c1a654..1f0bfe06 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -303,7 +303,10 @@ if __name__ == "__main__":
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
-    parser.add_argument("--enable_prefix_caching", action='store_true')
+    parser.add_argument(
+        "--enable-prefix-caching",
+        action='store_true',
+        help="enable automatic prefix caching for vLLM backend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py
index 08d519ab..daf83827 100644
--- a/vllm/core/block_manager.py
+++ b/vllm/core/block_manager.py
@@ -236,13 +236,6 @@ class BlockSpaceManager:
         token_ids_len = len(seq.data.get_token_ids())
         return token_ids_len > 0 and token_ids_len % seq.block_size == 0
 
-    def _is_last_block(
-        self,
-        seq: Sequence,
-        index: int,
-    ) -> bool:
-        return index == len(seq.logical_token_blocks) - 1
-
     def _maybe_promote_last_block(
         self,
         seq: Sequence,
@@ -436,7 +429,7 @@ class BlockSpaceManager:
     def compute_last_full_block_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // seq.block_size - 1
+        max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
@@ -451,9 +444,9 @@ class BlockSpaceManager:
             return [b.block_number for b in block_table[:block_idx + 1]]
         return []
 
-    # Can return non-empty result only with prefix caching enabled.
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
+        # Can return non-empty result only with prefix caching enabled.
        if not self.enable_caching:
             return []
 
         ids_list = [
             self.get_all_block_ids(seq)
             for seq in iter(seq_group.seqs_dict.values())
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    # We only mark the last full block because with prefix caching,
-    # all blocks until the marked one are guaranteed to be computed.
def mark_blocks_as_computed(self, seq_group: SequenceGroup): + # NOTE: We only mark the last full block because with prefix caching, + # all blocks until the marked one are guaranteed to be computed. if self.enable_caching: for seq in seq_group.seqs_dict.values(): self.compute_last_full_block_in_seq(seq) diff --git a/vllm/sequence.py b/vllm/sequence.py index 12296003..04a9a90a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -160,10 +160,10 @@ class Sequence: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - # TODO The current hashing function is O(L^2). We should optimize this in - # the future. def hash_of_block(self, logical_idx: int) -> int: # Compute the number of tokens in the sequence + # TODO: The current hashing function is O(L^2). We should optimize + # this in the future. num_tokens = self.num_hashed_tokens_of_block(logical_idx) return hash(tuple(self.data.get_token_ids()[0:num_tokens])) @@ -308,10 +308,6 @@ class SequenceGroup: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).data.prompt_token_ids - @property - def block_size(self) -> int: - return next(iter(self.seqs_dict.values())).block_size - @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0
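
A quick way to exercise the new benchmark (assuming a working vLLM build that includes this patch and access to the baichuan-inc/Baichuan2-13B-Chat weights) is to run it once with and once without the new flag and compare the reported times:

    # Baseline: no automatic prefix caching.
    python benchmarks/benchmark_prefix_caching.py

    # With automatic prefix caching enabled.
    python benchmarks/benchmark_prefix_caching.py --enable-prefix-caching

Because all 100 prompts share the same long table prefix, the second run should be able to reuse cached prefix blocks after the single-prompt warm-up pass.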
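
On get_common_computed_block_ids: assuming the commonprefix used in block_manager.py is os.path.commonprefix, it compares its inputs element-wise, so applied to per-sequence lists of block ids it returns the longest leading run shared by every sequence in the group. A minimal sketch of that behavior, with hypothetical block ids:

    from os.path import commonprefix

    # Hypothetical per-sequence lists of computed physical block ids.
    ids_list = [[0, 1, 2, 7], [0, 1, 2, 9], [0, 1, 2]]

    # commonprefix compares element-wise, so the result is the run of
    # block ids that every sequence has already computed: [0, 1, 2]
    print(commonprefix([ids for ids in ids_list if ids != []]))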