Fix slot mapping
This commit is contained in:
parent 8290fce47d
commit 4b1ac23f53
@@ -14,7 +14,7 @@ class InputMetadata:
         block_tables: torch.Tensor,
     ) -> None:
         self.prompt_lens = prompt_lens
-        self.prompt_block_table = slot_mapping
+        self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.max_context_len = max_context_len
         self.block_tables = block_tables
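For reference, the one-line change above only moves `slot_mapping` onto the attribute it belongs to: it was previously stashed in `prompt_block_table`, so code reading `input_metadata.slot_mapping` would not have found it. Below is a minimal sketch of the corrected constructor; the parameter types are inferred from the surrounding diff and are assumptions, since the full signature is not shown here.

import torch

class InputMetadata:
    # Sketch of the corrected constructor; types not visible in the diff
    # are assumptions.
    def __init__(
        self,
        prompt_lens: list,            # per-prompt lengths
        slot_mapping: torch.Tensor,   # flat KV-cache slot per token
        context_lens: torch.Tensor,   # context length per generation sequence
        max_context_len: int,
        block_tables: torch.Tensor,   # padded physical block tables
    ) -> None:
        self.prompt_lens = prompt_lens
        # Fixed: store the slot mapping under its own name instead of
        # stashing it in prompt_block_table.
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.max_context_len = max_context_len
        self.block_tables = block_tables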
@@ -83,12 +83,20 @@ class Worker:
         generation_seq_ids = sorted(generation_tokens.keys())
         for seq_id in generation_seq_ids:
             input_tokens.append(generation_tokens[seq_id])
-            input_positions.append(context_lens[seq_id] - 1)
-            generation_block_tables.append(block_tables[seq_id])
+            position_id = context_lens[seq_id] - 1
+            input_positions.append(position_id)
+
+            block_table = block_tables[seq_id]
+            generation_block_tables.append(block_table)
+
             max_context_len = max(max_context_len, context_lens[seq_id])
             max_num_blocks_per_seq = max(
-                max_num_blocks_per_seq, len(block_tables[seq_id]))
+                max_num_blocks_per_seq, len(block_table))
+
+            block_number = block_table[position_id // self.block_size]
+            block_offset = position_id % self.block_size
+            slot = block_number * self.block_size + block_offset
+            slot_mapping.append(slot)

         # Optimization: Pad the input length to be a multiple of 8.
         # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
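The new lines compute, for each generation token, the flat slot its KV entry occupies in the paged cache: the block table maps the token's logical block index (position_id // block_size) to a physical block number, and the slot is that block number scaled by the block size plus the offset inside the block. A small self-contained illustration of the same arithmetic follows; the block size and table values are made-up example data, not taken from the repo.

def compute_slot(block_table: list, position_id: int, block_size: int) -> int:
    """Map a token position to a flat slot index in the KV cache."""
    block_number = block_table[position_id // block_size]  # physical block
    block_offset = position_id % block_size                # offset in block
    return block_number * block_size + block_offset

# Example: block_size=8, the sequence's logical blocks live in physical
# blocks 5 and 12. Token position 11 falls in logical block 1 (11 // 8),
# i.e. physical block 12, at offset 3 (11 % 8) -> slot 12 * 8 + 3 = 99.
assert compute_slot([5, 12], position_id=11, block_size=8) == 99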
@@ -105,9 +113,11 @@ class Worker:
         context_lens_tensor = torch.tensor(
             [context_lens[seq_id] for seq_id in generation_seq_ids],
             dtype=torch.int, device=self.device)
+        padded_block_tables = [
+            _pad_to_max(block_table, max_num_blocks_per_seq)
+            for block_table in generation_block_tables]
         block_tables_tensor = torch.tensor(
-            [_pad_to_max(block_table) for block_table in generation_block_tables],
-            dtype=int, device=self.device)
+            padded_block_tables, dtype=int, device=self.device)

         input_metadata = InputMetadata(
             prompt_lens=prompt_lens,
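The call site now passes the target length to `_pad_to_max`, so every per-sequence block table is padded to `max_num_blocks_per_seq` before being stacked into a rectangular tensor. A hedged sketch of what such a helper typically looks like is below; the actual implementation is not shown in this diff, and padding with 0 is an assumption.

import torch

def _pad_to_max(x: list, max_len: int, pad: int = 0) -> list:
    # Sketch only: extend a block table to max_len entries so every row of
    # the resulting tensor has the same length. The real helper's pad value
    # is not visible in this diff; 0 is assumed here.
    return x + [pad] * (max_len - len(x))

# Example usage mirroring the diff: pad each block table, then stack.
generation_block_tables = [[5, 12], [7], [3, 9, 4]]
max_num_blocks_per_seq = max(len(bt) for bt in generation_block_tables)
padded_block_tables = [
    _pad_to_max(bt, max_num_blocks_per_seq) for bt in generation_block_tables]
block_tables_tensor = torch.tensor(padded_block_tables, dtype=torch.int)
# -> tensor of shape (3, 3): [[5, 12, 0], [7, 0, 0], [3, 9, 4]]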