From ccb63a8245bceb9e6ba260eeef41b54ca8bdb370 Mon Sep 17 00:00:00 2001
From: Kuntai Du
Date: Tue, 14 May 2024 05:34:33 -0700
Subject: [PATCH] [Core][Hash][Automatic Prefix caching] Accelerating the hashing function by avoiding deep copies (#4696)

---
 benchmarks/overheads/benchmark_hashing.py | 63 +++++++++++++++++++++++
 vllm/sequence.py                          | 16 +++++-
 2 files changed, 77 insertions(+), 2 deletions(-)
 create mode 100644 benchmarks/overheads/benchmark_hashing.py

diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py
new file mode 100644
index 00000000..c846e47d
--- /dev/null
+++ b/benchmarks/overheads/benchmark_hashing.py
@@ -0,0 +1,63 @@
+import argparse
+import cProfile
+import pstats
+
+from vllm import LLM, SamplingParams
+
+# A very long prompt, about 15k tokens in total.
+LONG_PROMPT = ["You are an expert in large language models, aren't you?"
+               ] * 1000
+LONG_PROMPT = ' '.join(LONG_PROMPT)
+
+
+def main(args):
+    llm = LLM(
+        model=args.model,
+        enforce_eager=True,
+        enable_prefix_caching=True,
+        tensor_parallel_size=args.tensor_parallel_size,
+        use_v2_block_manager=args.use_v2_block_manager,
+    )
+
+    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+    profiler = cProfile.Profile()
+
+    print("------warm up------")
+    for i in range(3):
+        output = llm.generate(LONG_PROMPT, sampling_params)
+        print(output[0].outputs[0].text)
+
+    print("------start generating------")
+    for i in range(3):
+        profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)',
+                        globals(), locals())
+
+    # Analyze the runtime of the hashing function.
+    stats = pstats.Stats(profiler)
+    stats.sort_stats('cumulative')
+    total_time = 0
+    total_calls = 0
+    for func in stats.stats:
+        if 'hash_of_block' in func[2]:
+            total_time = stats.stats[func][3]
+            total_calls = stats.stats[func][0]
+    percentage = (total_time / stats.total_tt) * 100
+    print(f"Hashing took {total_time:.2f} seconds, "
+          f"{percentage:.2f}% of the total runtime.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='Benchmark the performance of the hashing function in '
+        'automatic prefix caching.')
+    parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+    parser.add_argument('--output-len', type=int, default=10)
+    parser.add_argument('--enable-prefix-caching',
+                        action='store_true',
+                        help='enable prefix caching')
+    parser.add_argument('--use-v2-block-manager',
+                        action='store_true',
+                        help='Use BlockSpaceManagerV2')
+    args = parser.parse_args()
+    main(args)
diff --git a/vllm/sequence.py b/vllm/sequence.py
index 46ac33b7..12e930c2 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -121,6 +121,7 @@ class SequenceData:
             output_token_ids = []
 
         self.prompt_token_ids = prompt_token_ids
+        self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
         self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
@@ -143,6 +144,17 @@ class SequenceData:
     def get_token_ids(self) -> List[int]:
         return self.prompt_token_ids + self.output_token_ids
 
+    def get_prefix_token_ids(
+            self, num_tokens: int
+    ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]:
+        """Get prefix tokens, and make the return value hashable"""
+        prompt_length = len(self.prompt_token_ids)
+        if num_tokens > prompt_length:
+            return (self._prompt_token_ids_tuple,
+                    tuple(self.output_token_ids[:num_tokens - prompt_length]))
+        else:
+            return (self._prompt_token_ids_tuple[:num_tokens], None)
+
     def get_num_computed_tokens(self) -> int:
         """Return the number of prefill tokens that are already computed."""
         return self._num_computed_tokens
@@ -253,8 +265,8 @@ class Sequence:
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(
-            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
+        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
+        return hash((hashed_tokens, self.lora_int_id))
 
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size
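
Note (illustration, not part of the patch): the saving comes from SequenceData caching
tuple(prompt_token_ids) once at construction, so hash_of_block can slice that cached
tuple instead of concatenating the prompt and output token lists and converting the
result to a tuple on every call. A minimal standalone sketch of the idea, with
hypothetical names and sizes (a ~15k-token prompt and a block size of 16):

    # Hypothetical micro-benchmark sketch -- not vLLM code.
    import time

    prompt_token_ids = list(range(15_000))  # stand-in for a ~15k-token prompt
    output_token_ids = []                   # no generated tokens yet
    block_size = 16
    num_blocks = len(prompt_token_ids) // block_size
    lora_int_id = 0

    # Old path: copy the whole token list, slice it, then build a tuple per block.
    start = time.perf_counter()
    for i in range(1, num_blocks + 1):
        token_ids = prompt_token_ids + output_token_ids
        hash((tuple(token_ids[:i * block_size]), lora_int_id))
    old_s = time.perf_counter() - start

    # New path: slice a tuple that was built once, avoiding the per-call list copies.
    prompt_token_ids_tuple = tuple(prompt_token_ids)
    start = time.perf_counter()
    for i in range(1, num_blocks + 1):
        hash(((prompt_token_ids_tuple[:i * block_size], None), lora_int_id))
    new_s = time.perf_counter() - start

    print(f"old: {old_s:.3f}s  new: {new_s:.3f}s")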