[spec decode] [4/N] Move update_flash_attn_metadata to attn backend (#7571)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
parent 855866caa9
commit f366f6339b
@@ -75,6 +75,9 @@ class AttentionBackend(ABC):
     ) -> None:
         raise NotImplementedError
 
+    def advance_step(self, num_seqs: int, num_queries: int):
+        raise NotImplementedError
+
 
 @dataclass
 class AttentionMetadata:
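The base class only declares the hook and raises NotImplementedError, so a backend that never overrides it fails loudly the moment multi-step decode tries to use it. A minimal, self-contained illustration of that default behaviour (AttentionBackend below is a local stand-in, not the vLLM import):

    # Local stand-in mirroring the abstract hook added above; not vLLM code.
    class AttentionBackend:

        def advance_step(self, num_seqs: int, num_queries: int):
            raise NotImplementedError


    class UnsupportedBackend(AttentionBackend):
        # Deliberately does not override advance_step().
        pass


    try:
        UnsupportedBackend().advance_step(num_seqs=8, num_queries=5)
    except NotImplementedError:
        print("this backend does not implement multi-step advance_step")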
@@ -297,6 +297,51 @@ class FlashAttentionMetadata(AttentionMetadata):
         )
         return self._cached_decode_metadata
 
+    def advance_step(self, num_seqs: int, num_queries: int):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # GPU in-place update is currently called separately through
+        # custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
+        # this logic to the backend.
+
+        # When using cudagraph, num_seqs is padded to the next captured
+        # batch size, but num_queries tracks the actual number of requests in
+        # the batch. For --enforce-eager mode, num_seqs == num_queries.
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+        assert self.slot_mapping.shape == (num_seqs, )
+
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+        assert self.max_decode_seq_len == max(self.seq_lens)
+
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+
+        assert self.block_tables is not None
+        assert self.block_tables.shape[0] == num_seqs
+
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to captured cuda graph batch size.
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
+
 
 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
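The num_seqs / num_queries split that the assertions above guard comes from CUDA graph replay: the tensors are sized for the nearest captured batch size while num_queries stays at the real request count, and with --enforce-eager the two are equal. A rough sketch of that padding rule (the captured sizes and the helper name below are assumptions for illustration, not vLLM's actual API):

    import bisect

    # Hypothetical set of batch sizes captured for CUDA graphs.
    CAPTURED_BATCH_SIZES = [1, 2, 4, 8, 16, 32]


    def padded_num_seqs(num_queries: int, enforce_eager: bool) -> int:
        """Return the batch size the attention kernels will actually see."""
        if enforce_eager:
            # No graph capture, so no padding: num_seqs == num_queries.
            return num_queries
        assert num_queries <= CAPTURED_BATCH_SIZES[-1]
        # Graph replay needs fixed shapes, so pad up to the next captured size.
        return CAPTURED_BATCH_SIZES[bisect.bisect_left(CAPTURED_BATCH_SIZES,
                                                       num_queries)]


    assert padded_num_seqs(5, enforce_eager=False) == 8  # padded to captured size
    assert padded_num_seqs(5, enforce_eager=True) == 5   # num_seqs == num_queries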
@@ -97,38 +97,6 @@ class TP1DraftModelRunner(ModelRunner):
         self.flashinfer_prefill_workspace_buffer = None
         self.flashinfer_prefill_wrapper = None
 
-    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
-                                    num_queries):
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
-
-        if num_seqs != num_queries:
-            assert num_seqs > num_queries
-            assert attn_metadata.use_cuda_graph
-
-        assert attn_metadata.num_prefills == 0
-        assert attn_metadata.num_prefill_tokens == 0
-        assert attn_metadata.num_decode_tokens == num_seqs
-        assert attn_metadata.slot_mapping.shape == (num_seqs, )
-
-        assert len(attn_metadata.seq_lens) == num_seqs
-        assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
-        assert attn_metadata.max_query_len == 1
-        assert attn_metadata.max_prefill_seq_len == 0
-        assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
-
-        assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
-        assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
-
-        assert attn_metadata.context_lens_tensor.shape == (num_queries, )
-
-        assert attn_metadata.block_tables.shape[0] == num_seqs
-
-        # Update query lengths. Note that we update only queries and not seqs,
-        # since tensors may be padded due to captured cuda graph batch size
-        for i in range(num_queries):
-            attn_metadata.seq_lens[i] += 1
-        attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
-
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
@@ -166,7 +134,7 @@ class TP1DraftModelRunner(ModelRunner):
         # Update attn_metadata
         attn_metadata = model_input.attn_metadata
         assert isinstance(attn_metadata, FlashAttentionMetadata)
-        self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
+        attn_metadata.advance_step(num_seqs, num_queries)
 
         # Update GPU tensors
         ops.advance_step(num_seqs=num_seqs,
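With the helper gone from the runner, the per-step flow is: the backend metadata object advances its own CPU-side bookkeeping, and the GPU tensors are still updated separately through ops.advance_step(), as the TODO in the new method notes. A toy, runnable walk-through of that call order (FakeOps and FakeMetadata are stand-ins, not vLLM classes):

    class FakeOps:
        """Stand-in for the custom-op wrapper behind ops.advance_step()."""

        def advance_step(self, num_seqs: int, num_queries: int) -> None:
            print(f"GPU tensors advanced for {num_queries}/{num_seqs} slots")


    class FakeMetadata:
        """Stand-in for FlashAttentionMetadata's CPU-side bookkeeping."""

        def __init__(self, seq_lens):
            self.seq_lens = list(seq_lens)

        def advance_step(self, num_seqs: int, num_queries: int) -> None:
            # Only the real requests grew by one token; padded slots stay put.
            for i in range(num_queries):
                self.seq_lens[i] += 1


    ops = FakeOps()
    attn_metadata = FakeMetadata(seq_lens=[7, 7, 7, 7, 7, 0, 0, 0])
    num_queries, num_seqs = 5, 8  # 5 real requests padded to a captured size of 8

    for _ in range(3):  # a few speculative decode steps
        # 1) Backend-owned CPU-side metadata update (what this commit moves).
        attn_metadata.advance_step(num_seqs, num_queries)
        # 2) Separate in-place GPU tensor update (unchanged by this commit).
        ops.advance_step(num_seqs=num_seqs, num_queries=num_queries)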