[TPU] Fix TPU SMEM OOM by Pallas paged attention kernel (#9350)
parent fd47e57f4b
commit 473e7b3606
@@ -208,35 +208,54 @@ class PallasAttentionBackendImpl(AttentionImpl):
         else:
             # Decoding run.
             assert kv_cache[0].numel() > 0
-
+            query = query.squeeze(dim=1)
             pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
-            if self.megacore_mode == "batch" and batch_size % 2 != 0:
-                megacore_mode = None
-            else:
-                megacore_mode = self.megacore_mode
-
-            # NOTE(woosuk): A temporary workaround to avoid the error:
-            # "xla::paged_attention() Expected a value of type 'str' for
-            # argument 'megacore_mode' but instead found type 'NoneType'."
-            if megacore_mode is not None:
-                output = torch.ops.xla.paged_attention(
-                    query.squeeze(dim=1),
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                    megacore_mode=megacore_mode,
-                )
-            else:
-                output = torch.ops.xla.paged_attention(
-                    query.squeeze(dim=1),
-                    key_cache,
-                    value_cache,
-                    attn_metadata.context_lens,
-                    attn_metadata.block_tables,
-                    pages_per_compute_block,
-                )
+
+            assert attn_metadata.block_tables is not None
+            assert attn_metadata.context_lens is not None
+            # NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
+            # block table in SMEM. Therefore, if the block table is too large,
+            # the kernel compilation will fail. To avoid this, we split the
+            # batch dimension into smaller chunks and run the kernel multiple
+            # times.
+            MAX_SMEM_USAGE = 512 * 1024
+            size_per_seq = 4 * attn_metadata.block_tables.shape[1]
+            max_num_seq = MAX_SMEM_USAGE // size_per_seq
+
+            if batch_size <= max_num_seq:
+                output = paged_attention(
+                    query,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.context_lens,
+                    attn_metadata.block_tables,
+                    pages_per_compute_block,
+                    self.megacore_mode,
+                )
+            else:
+                chunk_size = max_num_seq
+                # Make sure the chunk size is a multiple of 2.
+                chunk_size = chunk_size // 2 * 2
+                num_chunks = (batch_size + chunk_size - 1) // chunk_size
+
+                output = torch.empty_like(query)
+                for chunk_idx in range(num_chunks):
+                    chunk_start = chunk_idx * chunk_size
+                    chunk_end = chunk_start + chunk_size
+                    # NOTE(woosuk): We skip this line because it causes Dynamo
+                    # compilation error. Instead, we rely on the slice operation
+                    # to handle the out-of-bound case.
+                    # chunk_end = min(chunk_end, batch_size)
+                    chunk_output = paged_attention(
+                        query[chunk_start:chunk_end],
+                        key_cache,
+                        value_cache,
+                        attn_metadata.context_lens[chunk_start:chunk_end],
+                        attn_metadata.block_tables[chunk_start:chunk_end],
+                        pages_per_compute_block,
+                        self.megacore_mode,
+                    )
+                    output[chunk_start:chunk_end] = chunk_output
 
         # Reshape the output tensor.
         return output.reshape(batch_size, seq_len, hidden_size)
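The heart of the fix is the chunked decode path in the hunk above: the Pallas kernel keeps the block table in SMEM, so the code caps how many sequences go into one kernel call (roughly 512 KiB divided by 4 bytes per block-table entry per sequence) and processes larger batches in even-sized chunks. For scale, with a 32768-token context and 16-token blocks each sequence needs 2048 int32 entries (8 KiB), so at most 64 sequences fit per call. Below is a minimal, self-contained sketch of that pattern; fake_paged_attention and chunked_decode are illustrative stand-ins that run without a TPU or torch_xla, not code from this commit.

import torch


def fake_paged_attention(query: torch.Tensor) -> torch.Tensor:
    # Stand-in for torch.ops.xla.paged_attention; it just echoes the query so
    # the sketch runs anywhere.
    return query.clone()


def chunked_decode(query: torch.Tensor, max_num_seq: int) -> torch.Tensor:
    batch_size = query.shape[0]
    if batch_size <= max_num_seq:
        # Small batch: the whole block table fits in SMEM, one kernel call.
        return fake_paged_attention(query)
    # Keep the chunk size even so megacore "batch" mode remains usable.
    chunk_size = max_num_seq // 2 * 2
    num_chunks = (batch_size + chunk_size - 1) // chunk_size
    output = torch.empty_like(query)
    for chunk_idx in range(num_chunks):
        chunk_start = chunk_idx * chunk_size
        # No min(chunk_end, batch_size): slicing past the end simply returns
        # the remaining rows, which is how the last partial chunk is handled.
        chunk_end = chunk_start + chunk_size
        output[chunk_start:chunk_end] = fake_paged_attention(
            query[chunk_start:chunk_end])
    return output


out = chunked_decode(torch.randn(10, 8), max_num_seq=4)
assert out.shape == (10, 8)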
@@ -258,3 +277,43 @@ def write_to_kv_cache(
     value_cache = value_cache.flatten(0, 2)
     key_cache.index_copy_(0, slot_mapping, key)
     value_cache.index_copy_(0, slot_mapping, value)
+
+
+def paged_attention(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    pages_per_compute_block: int,
+    megacore_mode: Optional[str],
+) -> torch.Tensor:
+    batch_size = query.shape[0]
+    if megacore_mode == "batch" and batch_size % 2 != 0:
+        megacore_mode = None
+    else:
+        megacore_mode = megacore_mode
+
+    # NOTE(woosuk): A temporary workaround to avoid the error:
+    # "xla::paged_attention() Expected a value of type 'str' for
+    # argument 'megacore_mode' but instead found type 'NoneType'."
+    if megacore_mode is not None:
+        output = torch.ops.xla.paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            context_lens,
+            block_tables,
+            pages_per_compute_block,
+            megacore_mode=megacore_mode,
+        )
+    else:
+        output = torch.ops.xla.paged_attention(
+            query,
+            key_cache,
+            value_cache,
+            context_lens,
+            block_tables,
+            pages_per_compute_block,
+        )
+    return output
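Two things in the new paged_attention helper are easy to miss. The None branch exists only because the XLA op declares megacore_mode as a string and raises "Expected a value of type 'str' ... but instead found type 'NoneType'" when passed None, so the keyword has to be omitted entirely. And "batch" megacore mode is dropped for odd batch sizes, presumably because that mode splits the batch across the chip's two TensorCores and needs an even split. A toy sketch of the fallback rule; resolve_megacore_mode is an illustrative name, not part of the commit.

from typing import Optional


def resolve_megacore_mode(requested: Optional[str],
                          batch_size: int) -> Optional[str]:
    # "batch" megacore mode needs an even batch, so fall back to no megacore
    # parallelism when the batch size is odd; anything else passes through.
    if requested == "batch" and batch_size % 2 != 0:
        return None
    return requested


print(resolve_megacore_mode("batch", 7))  # None
print(resolve_megacore_mode("batch", 8))  # batch
print(resolve_megacore_mode(None, 7))     # None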
@@ -123,6 +123,15 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
         )
         self.cached_step_outputs: List[torch.Tensor] = []
 
+        smem_size = 512 * 1024
+        block_table_size = 4 * self.block_tables.size
+        if block_table_size >= smem_size:
+            logger.warning(
+                "The max_model_len (%d) is too large. This may degrade the "
+                "performance due to the insufficient smem size. Consider "
+                "setting --max-model-len to a smaller value.",
+                self.model_config.max_model_len)
+
     def load_model(self) -> None:
         self.device = self.device_config.device
 
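The new check in TPUModelRunner applies the same 512 KiB SMEM budget to the whole preallocated block table and only warns, since the chunked kernel path above keeps decoding functional, just potentially slower. A small worked example of when the warning fires, assuming self.block_tables is an int32 array with max_model_len // block_size entries per sequence; block_table_bytes and the block_size default of 16 are illustrative assumptions, not code from this commit.

SMEM_SIZE = 512 * 1024  # bytes, the same budget as MAX_SMEM_USAGE above


def block_table_bytes(max_num_seqs: int, max_model_len: int,
                      block_size: int = 16) -> int:
    # 4 bytes per int32 entry, one entry per block per sequence.
    blocks_per_seq = max_model_len // block_size
    return 4 * max_num_seqs * blocks_per_seq


# 256 sequences at max_model_len=8192 already consume the full 512 KiB budget,
# so the warning fires and suggests lowering --max-model-len.
print(block_table_bytes(256, 8192))               # 524288
print(block_table_bytes(256, 8192) >= SMEM_SIZE)  # True
print(block_table_bytes(256, 4096) >= SMEM_SIZE)  # False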