From 309aaef8255fb832bf674c6ed7d9d84211629421 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Wed, 24 Jul 2024 22:33:56 -0700
Subject: [PATCH] [Bugfix] Fix decode tokens w. CUDA graph (#6757)

---
 tests/worker/test_model_runner.py     |  1 +
 vllm/attention/backends/flash_attn.py | 12 ++++++++++--
 vllm/attention/backends/flashinfer.py | 11 ++++++++++-
 vllm/attention/backends/utils.py      | 11 ++++++++++-
 4 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index b5742c43..4a0e2b41 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -193,6 +193,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     for _ in range(expected_bs - len(seq_lens)):
         seq_lens.append(1)
     assert attn_metadata.seq_lens == seq_lens
+    assert attn_metadata.num_decode_tokens == len(seq_lens)
     start_idx = 0
     start_loc = [start_idx]
     for _ in context_lens:
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 949bd973..7d7aff9d 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -272,7 +272,15 @@ class FlashAttentionMetadataBuilder(

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
-        """Build attention metadata with on-device tensors."""
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -297,7 +305,7 @@ class FlashAttentionMetadataBuilder(
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 97463043..83a420d7 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -320,6 +320,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -334,7 +343,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 5877712b..dcd10ed4 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -149,6 +149,15 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):

     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
         for inter_data in self.input_builder.inter_data_list:
             self._add_seq_group(inter_data,
                                 self.input_builder.chunked_prefill_enabled)
@@ -173,7 +182,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size + cuda_graph_pad_size
+            num_decode_tokens = batch_size

         # The shape of graph_block_tables is
         # [max batch size, max context len // block size].
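Note (reviewer sketch, not part of the patch): a minimal illustration of the padding
arithmetic behind the one-line fix, using made-up sizes. Per the new docstrings, the
batch_size passed to build() is already padded up to the captured CUDA graph batch
size, so the old expression added cuda_graph_pad_size a second time and over-counted
the decode tokens; the test now checks num_decode_tokens against the padded seq_lens.

    # Hypothetical numbers, for illustration only.
    actual_batch_size = 3                                        # real decode sequences
    graph_batch_size = 8                                         # nearest captured CUDA graph size
    cuda_graph_pad_size = graph_batch_size - actual_batch_size   # 5 padded entries

    # build() already receives the padded size as batch_size:
    batch_size = graph_batch_size

    num_decode_tokens_old = batch_size + cuda_graph_pad_size     # 13 -- double-counts the padding
    num_decode_tokens_new = batch_size                           # 8  -- matches len(seq_lens) in the test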