[Bugfix] Fix decode tokens w. CUDA graph (#6757)

commit 309aaef825 (parent 9e169a4c61)
Author: Cody Yu
Date: 2024-07-24 22:33:56 -07:00 (committed by GitHub)
4 changed files with 31 additions and 4 deletions


@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     for _ in range(expected_bs - len(seq_lens)):
         seq_lens.append(1)
     assert attn_metadata.seq_lens == seq_lens
+    assert attn_metadata.num_decode_tokens == len(seq_lens)
     start_idx = 0
     start_loc = [start_idx]
     for _ in context_lens:
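For context on why the new assertion expects len(seq_lens): during decode with CUDA graphs, the batch is padded up to a captured graph size before attention metadata is built, so each padded slot also counts as one decode token. The sketch below is illustrative only; the captured sizes and the pad_decode_batch helper are hypothetical, not vLLM APIs.

# Illustrative only: mimics how a decode batch is padded up to a CUDA-graph
# batch size before attention metadata is built (names are hypothetical).
from typing import List, Tuple

GRAPH_BATCH_SIZES = [1, 2, 4, 8, 16, 32]  # assumed captured sizes

def pad_decode_batch(seq_lens: List[int]) -> Tuple[List[int], int]:
    """Pad seq_lens up to the smallest captured graph batch size."""
    expected_bs = next(bs for bs in GRAPH_BATCH_SIZES if bs >= len(seq_lens))
    padded = seq_lens + [1] * (expected_bs - len(seq_lens))  # dummy length 1
    # After the fix, num_decode_tokens equals the padded batch size:
    # one decode token per (real or padded) sequence slot.
    num_decode_tokens = len(padded)
    return padded, num_decode_tokens

padded, num_decode_tokens = pad_decode_batch([17, 9, 5])
assert padded == [17, 9, 5, 1] and num_decode_tokens == 4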


@@ -272,7 +272,15 @@ class FlashAttentionMetadataBuilder(
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        """Build attention metadata with on-device tensors."""
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -297,7 +305,7 @@ class FlashAttentionMetadataBuilder(
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
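The hunk above is the core fix: the batch_size handed to build() is already the maybe-padded batch size (per the new docstring), so adding cuda_graph_pad_size on top counted the padding twice. A small numeric sketch, assuming a captured graph batch size of 8:

# Numeric sketch of the off-by-padding bug (graph size of 8 is assumed).
real_bs = 3                                # actual decode sequences
graph_bs = 8                               # captured CUDA-graph batch size
cuda_graph_pad_size = graph_bs - real_bs   # 5 padded slots
batch_size = graph_bs                      # "maybe padded" value passed to build()

buggy = batch_size + cuda_graph_pad_size   # 13: counts the padding twice
fixed = batch_size                         # 8: one decode token per slot

assert fixed == real_bs + cuda_graph_pad_size
assert buggy != fixed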


@@ -320,6 +320,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -334,7 +343,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
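The FlashInfer builder gets the same docstring and the same one-token-per-slot fix. As the docstring notes, cuda_graph_pad_size == -1 signals that CUDA graphs are not in use; a rough sketch of consuming that sentinel follows (the PAD_SLOT_ID value and the helper are assumptions, not the exact vLLM control flow):

# Rough sketch of the -1 sentinel; PAD_SLOT_ID value and the helper below
# are assumptions, not the exact vLLM control flow.
from typing import List, Tuple

PAD_SLOT_ID = -1  # assumed id for slots that should write nowhere meaningful

def finalize_decode_metadata(slot_mapping: List[int], batch_size: int,
                             cuda_graph_pad_size: int) -> Tuple[List[int], int]:
    use_captured_graph = cuda_graph_pad_size != -1
    if use_captured_graph:
        # Pad per-token buffers so they match the captured graph's shape.
        slot_mapping = slot_mapping + [PAD_SLOT_ID] * cuda_graph_pad_size
    # batch_size is already padded upstream when CUDA graphs are used,
    # so it is the decode-token count in both modes.
    return slot_mapping, batch_size

slots, n = finalize_decode_metadata([40, 12, 7], batch_size=8, cuda_graph_pad_size=5)
assert len(slots) == 8 and n == 8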


@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -173,7 +182,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
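CommonMetadataBuilder applies the identical change for the backends that share this code path. The invariant the commit restores is that every per-slot buffer and num_decode_tokens agree on the padded batch size a captured graph expects at replay time; a minimal, hypothetical consistency check (not vLLM code):

# Hypothetical consistency check, not vLLM code: at CUDA-graph replay time,
# every per-slot buffer and num_decode_tokens must agree on the padded batch.
from dataclasses import dataclass
from typing import List

@dataclass
class DecodeMetadata:
    seq_lens: List[int]        # padded to the captured batch size
    slot_mapping: List[int]    # padded with PAD_SLOT_ID entries
    num_decode_tokens: int

    def check(self, graph_bs: int) -> None:
        assert len(self.seq_lens) == graph_bs
        assert len(self.slot_mapping) == graph_bs
        # The invariant restored by this commit:
        assert self.num_decode_tokens == graph_bs

DecodeMetadata(seq_lens=[17, 9, 5, 1],
               slot_mapping=[40, 12, 7, -1],
               num_decode_tokens=4).check(graph_bs=4)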