From 473e7b3606e9b95b39c7da46cce00a33c069dc00 Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Mon, 14 Oct 2024 15:02:06 -0700
Subject: [PATCH] [TPU] Fix TPU SMEM OOM by Pallas paged attention kernel
 (#9350)

---
 vllm/attention/backends/pallas.py | 99 ++++++++++++++++++++++++-------
 vllm/worker/tpu_model_runner.py   |  9 +++
 2 files changed, 88 insertions(+), 20 deletions(-)

diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index 86716602..56d3d3b4 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -208,35 +208,54 @@ class PallasAttentionBackendImpl(AttentionImpl):
         else:
             # Decoding run.
             assert kv_cache[0].numel() > 0
-
+            query = query.squeeze(dim=1)
             pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
-            if self.megacore_mode == "batch" and batch_size % 2 != 0:
-                megacore_mode = None
-            else:
-                megacore_mode = self.megacore_mode
 
-            # NOTE(woosuk): A temporary workaround to avoid the error:
-            # "xla::paged_attention() Expected a value of type 'str' for
-            # argument 'megacore_mode' but instead found type 'NoneType'."
-            if megacore_mode is not None:
-                output = torch.ops.xla.paged_attention(
-                    query.squeeze(dim=1),
+            assert attn_metadata.block_tables is not None
+            assert attn_metadata.context_lens is not None
+            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
+            # block table in SMEM. Therefore, if the block table is too large,
+            # the kernel compilation will fail. To avoid this, we split the
+            # batch dimension into smaller chunks and run the kernel multiple
+            # times.
+            MAX_SMEM_USAGE = 512 * 1024
+            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
+            max_num_seq = MAX_SMEM_USAGE // size_per_seq
+
+            if batch_size <= max_num_seq:
+                output = paged_attention(
+                    query,
                     key_cache,
                     value_cache,
                     attn_metadata.context_lens,
                     attn_metadata.block_tables,
                     pages_per_compute_block,
-                    megacore_mode=megacore_mode,
+                    self.megacore_mode,
                 )
             else:
-                output = torch.ops.xla.paged_attention(
-                    query.squeeze(dim=1),
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                )
+                chunk_size = max_num_seq
+                # Make sure the chunk size is a multiple of 2.
+                chunk_size = chunk_size // 2 * 2
+                num_chunks = (batch_size + chunk_size - 1) // chunk_size
+
+                output = torch.empty_like(query)
+                for chunk_idx in range(num_chunks):
+                    chunk_start = chunk_idx * chunk_size
+                    chunk_end = chunk_start + chunk_size
+                    # NOTE(woosuk): We skip this line because it causes Dynamo
+                    # compilation error. Instead, we rely on the slice operation
+                    # to handle the out-of-bound case.
+                    # chunk_end = min(chunk_end, batch_size)
+                    chunk_output = paged_attention(
+                        query[chunk_start:chunk_end],
+                        key_cache,
+                        value_cache,
+                        attn_metadata.context_lens[chunk_start:chunk_end],
+                        attn_metadata.block_tables[chunk_start:chunk_end],
+                        pages_per_compute_block,
+                        self.megacore_mode,
+                    )
+                    output[chunk_start:chunk_end] = chunk_output
 
         # Reshape the output tensor.
         return output.reshape(batch_size, seq_len, hidden_size)
@@ -258,3 +277,43 @@ def write_to_kv_cache(
     value_cache = value_cache.flatten(0, 2)
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
+
+
+def paged_attention(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    pages_per_compute_block: int,
+    megacore_mode: Optional[str],
+) -> torch.Tensor:
+    batch_size = query.shape[0]
+    if megacore_mode == "batch" and batch_size % 2 != 0:
+        megacore_mode = None
+    else:
+        megacore_mode = megacore_mode
+
+    # NOTE(woosuk): A temporary workaround to avoid the error:
+    # "xla::paged_attention() Expected a value of type 'str' for
+    # argument 'megacore_mode' but instead found type 'NoneType'."
+    if megacore_mode is not None:
+        output = torch.ops.xla.paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            context_lens,
+            block_tables,
+            pages_per_compute_block,
+            megacore_mode=megacore_mode,
+        )
+    else:
+        output = torch.ops.xla.paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            context_lens,
+            block_tables,
+            pages_per_compute_block,
+        )
+    return output
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index c13e95f6..f7e5f660 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -123,6 +123,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         self.cached_step_outputs: List[torch.Tensor] = []
 
+        smem_size = 512 * 1024
+        block_table_size = 4 * self.block_tables.size
+        if block_table_size >= smem_size:
+            logger.warning(
+                "The max_model_len (%d) is too large. This may degrade the "
+                "performance due to the insufficient smem size. Consider "
+                "setting --max-model-len to a smaller value.",
+                self.model_config.max_model_len)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
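
For reference, the chunking in the pallas.py hunk reduces to simple integer arithmetic: each sequence contributes one int32 block-table row of block_tables.shape[1] entries (4 bytes each), and the kernel is assumed to have roughly a 512 KiB SMEM budget, so at most MAX_SMEM_USAGE // size_per_seq sequences fit in one kernel invocation. Below is a minimal standalone sketch of that arithmetic, not part of the patch; the shapes (max_num_blocks_per_seq = 2048, batch_size = 100) are made-up illustrative values, not taken from vLLM.

    # Standalone sketch of the chunking arithmetic used in the patch above.
    # All shapes here are illustrative assumptions, not vLLM defaults.
    MAX_SMEM_USAGE = 512 * 1024           # assumed SMEM budget in bytes
    max_num_blocks_per_seq = 2048         # stands in for block_tables.shape[1]
    batch_size = 100                      # number of decode sequences in the batch

    size_per_seq = 4 * max_num_blocks_per_seq      # one int32 block-table row per sequence
    max_num_seq = MAX_SMEM_USAGE // size_per_seq   # sequences that fit in one kernel call

    chunk_size = max_num_seq // 2 * 2              # keep even for megacore "batch" mode
    num_chunks = (batch_size + chunk_size - 1) // chunk_size
    print(max_num_seq, chunk_size, num_chunks)     # -> 64 64 2 for these numbers

With these example numbers, at most 64 sequences fit per kernel invocation, so a batch of 100 decode sequences is processed in two chunks.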