[FIX] Fix styles in automatic prefix caching & add an automatic prefix caching benchmark (#3158)
commit 996d095c54 (parent d65fac2738)
benchmarks/benchmark_prefix_caching.py (new file, 59 lines added)
@@ -0,0 +1,59 @@
+import argparse
+import time
+
+from vllm import LLM
+from vllm import SamplingParams
+
+PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n"
+
+
+def test_prefix(llm=None, sampling_params=None, prompts=None, prefix_len=None):
+    start_time = time.time()
+    # whether use Prefix
+    if prefix_len != None:
+        # start inference
+        llm.generate(prompts,
+                     sampling_params=sampling_params,
+                     prefix_pos=prefix_len)
+    else:
+        llm.generate(prompts, sampling_params=sampling_params)
+
+    end_time = time.time()
+    print(f"cost time {end_time - start_time}")
+
+
+def main(args):
+    llm = LLM(model="baichuan-inc/Baichuan2-13B-Chat",
+              tokenizer_mode='auto',
+              trust_remote_code=True,
+              enforce_eager=True,
+              enable_prefix_caching=args.enable_prefix_caching)
+
+    num_prompts = 100
+    prompts = [PROMPT] * num_prompts
+    sampling_params = SamplingParams(temperature=0, max_tokens=100)
+
+    print("------warm up------")
+    test_prefix(
+        llm=llm,
+        prompts=prompts[:1],
+        sampling_params=sampling_params,
+    )
+
+    print("------start generating------")
+    test_prefix(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description='Benchmark the performance with or without automatic '
+        'prefix caching.')
+    parser.add_argument('--enable-prefix-caching',
+                        action='store_true',
+                        help='enable prefix caching')
+    args = parser.parse_args()
+    main(args)
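For reference, the new benchmark can be run back to back with and without the feature (assuming a working vLLM install and access to the Baichuan2-13B-Chat weights); each run prints the elapsed generation time for 100 identical prompts:

    python benchmarks/benchmark_prefix_caching.py
    python benchmarks/benchmark_prefix_caching.py --enable-prefix-caching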
@@ -303,7 +303,10 @@ if __name__ == "__main__":
         default="cuda",
         choices=["cuda"],
         help='device type for vLLM execution, supporting CUDA only currently.')
-    parser.add_argument("--enable_prefix_caching", action='store_true')
+    parser.add_argument(
+        "--enable-prefix-caching",
+        action='store_true',
+        help="enable automatic prefix caching for vLLM backend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
@@ -236,13 +236,6 @@ class BlockSpaceManager:
         token_ids_len = len(seq.data.get_token_ids())
         return token_ids_len > 0 and token_ids_len % seq.block_size == 0
 
-    def _is_last_block(
-        self,
-        seq: Sequence,
-        index: int,
-    ) -> bool:
-        return index == len(seq.logical_token_blocks) - 1
-
     def _maybe_promote_last_block(
         self,
         seq: Sequence,
@@ -436,7 +429,7 @@ class BlockSpaceManager:
     def compute_last_full_block_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // seq.block_size - 1
+        max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
@@ -451,9 +444,9 @@ class BlockSpaceManager:
                 return [b.block_number for b in block_table[:block_idx + 1]]
         return []
 
-    # Can return non-empty result only with prefix caching enabled.
    def get_common_computed_block_ids(self,
                                      seq_group: SequenceGroup) -> List[int]:
+        # Can return non-empty result only with prefix caching enabled.
         if not self.enable_caching:
             return []
 
@@ -463,9 +456,9 @@ class BlockSpaceManager:
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    # We only mark the last full block because with prefix caching,
-    # all blocks until the marked one are guaranteed to be computed.
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+        # NOTE: We only mark the last full block because with prefix caching,
+        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
                 self.compute_last_full_block_in_seq(seq)
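The commonprefix call in the context above relies on longest-common-prefix semantics over lists of block ids. A standalone sketch of that behavior (made-up block ids; uses the standard-library os.path.commonprefix, which compares element-wise and so also works on lists of ints):

    from os.path import commonprefix

    ids_list = [
        [0, 1, 2, 5],  # hypothetical computed block ids of one sequence
        [0, 1, 2, 7],  # hypothetical computed block ids of another sequence
    ]
    # Longest shared prefix of the block-id lists across sequences in a group.
    print(commonprefix([ids for ids in ids_list if ids != []]))  # [0, 1, 2]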
@@ -160,10 +160,10 @@ class Sequence:
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
-    # TODO The current hashing function is O(L^2). We should optimize this in
-    # the future.
     def hash_of_block(self, logical_idx: int) -> int:
         # Compute the number of tokens in the sequence
+        # TODO: The current hashing function is O(L^2). We should optimize
+        # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
         return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
 
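The relocated TODO concerns the cost of the hashing scheme: each block's hash covers every token from the start of the sequence through that block, so hashing all blocks of an L-token sequence touches on the order of L^2 tokens. A minimal standalone sketch of that scheme (hypothetical helper, not the vLLM implementation):

    def hash_of_full_block(token_ids, block_size, logical_idx):
        # The hash of block i covers tokens [0, (i + 1) * block_size),
        # i.e. the whole prefix of the sequence, not just the block itself.
        num_tokens = (logical_idx + 1) * block_size
        return hash(tuple(token_ids[:num_tokens]))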
@@ -308,10 +308,6 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return next(iter(self.seqs_dict.values())).data.prompt_token_ids
 
-    @property
-    def block_size(self) -> int:
-        return next(iter(self.seqs_dict.values())).block_size
-
     @property
     def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0