Fix slot mapping

Woosuk Kwon 2023-02-23 00:10:07 +00:00
parent 8290fce47d
commit 4b1ac23f53
2 changed files with 16 additions and 6 deletions

@@ -14,7 +14,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.prompt_lens = prompt_lens
-        self.prompt_block_table = slot_mapping
+        self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
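
For context on the first hunk: the constructor received a slot_mapping argument but stored it under the unrelated name prompt_block_table, so any downstream reader of input_metadata.slot_mapping would hit an AttributeError. A minimal sketch of the failure mode (the class here is an illustrative stand-in reduced to the one relevant field, not the real InputMetadata):

class InputMetadataBefore:  # illustrative stand-in, not the real class
    def __init__(self, slot_mapping):
        self.prompt_block_table = slot_mapping  # bug: wrong attribute name

meta = InputMetadataBefore(slot_mapping=[42, 43])
print(hasattr(meta, "slot_mapping"))  # False -- readers of .slot_mapping break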

@@ -83,12 +83,20 @@ class Worker:
         generation_seq_ids = sorted(generation_tokens.keys())
         for seq_id in generation_seq_ids:
             input_tokens.append(generation_tokens[seq_id])
-            input_positions.append(context_lens[seq_id] - 1)
-            generation_block_tables.append(block_tables[seq_id])
+            position_id = context_lens[seq_id] - 1
+            input_positions.append(position_id)
+            block_table = block_tables[seq_id]
+            generation_block_tables.append(block_table)
             max_context_len = max(max_context_len, context_lens[seq_id])
             max_num_blocks_per_seq = max(
-                max_num_blocks_per_seq, len(block_tables[seq_id]))
+                max_num_blocks_per_seq, len(block_table))
+            block_number = block_table[position_id // self.block_size]
+            block_offset = position_id % self.block_size
+            slot = block_number * self.block_size + block_offset
+            slot_mapping.append(slot)
         # Optimization: Pad the input length to be a multiple of 8.
         # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
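
The lines added in this hunk compute, for each generation token, the flat index ("slot") of its KV-cache entry: the block table maps the token's logical block (position_id // block_size) to a physical block number, and the slot is that block's base address plus the in-block offset. A worked sketch with made-up values (block_size = 8 and the block table below are illustrative assumptions, not from the commit):

block_size = 8                  # tokens per KV-cache block (illustrative)
block_table = [7, 2, 5]         # logical block -> physical block number (made up)

position_id = 18                # zero-based position of the token being generated
block_number = block_table[position_id // block_size]  # 18 // 8 = 2 -> physical block 5
block_offset = position_id % block_size                # 18 % 8 = 2
slot = block_number * block_size + block_offset        # 5 * 8 + 2 = 42
assert slot == 42
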
@@ -105,9 +113,11 @@ class Worker:
         context_lens_tensor = torch.tensor(
             [context_lens[seq_id] for seq_id in generation_seq_ids],
             dtype=torch.int, device=self.device)
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq)
+            for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
-            [_pad_to_max(block_table) for block_table in generation_block_tables],
-            dtype=int, device=self.device)
+            padded_block_tables, dtype=int, device=self.device)
         input_metadata = InputMetadata(
             prompt_lens=prompt_lens,
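
The last hunk fixes the block-table batching: torch.tensor needs a rectangular nested list, so each per-sequence block table is first right-padded to max_num_blocks_per_seq, and that length is now passed to _pad_to_max, which the old call left out. The helper itself is not shown in this commit; a plausible sketch, assuming right-padding with zeros:

def _pad_to_max(x, max_len, pad=0):  # hypothetical reconstruction of the helper
    # Right-pad to max_len; padded entries are never dereferenced because
    # context_lens bounds how far each block-table row is read.
    return x + [pad] * (max_len - len(x))

padded = [_pad_to_max(bt, 4) for bt in [[7, 2, 5], [3]]]
assert padded == [[7, 2, 5, 0], [3, 0, 0, 0]]  # rectangular, safe for torch.tensor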