[spec decode] [4/N] Move update_flash_attn_metadata to attn backend (#7571)
Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
commit f366f6339b
parent 855866caa9
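The net effect for callers: the CPU-side metadata update moves from a private helper on TP1DraftModelRunner onto the FlashAttention metadata object itself. Condensed from the hunks below:

    # before (private helper on the draft model runner, removed below)
    self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)

    # after (backend-owned metadata method, added below)
    attn_metadata.advance_step(num_seqs, num_queries)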
@@ -75,6 +75,9 @@ class AttentionBackend(ABC):
     ) -> None:
         raise NotImplementedError
 
+    def advance_step(self, num_seqs: int, num_queries: int):
+        raise NotImplementedError
+
 
 @dataclass
 class AttentionMetadata:
@@ -297,6 +297,51 @@ class FlashAttentionMetadata(AttentionMetadata):
             )
         return self._cached_decode_metadata
 
+    def advance_step(self, num_seqs: int, num_queries: int):
+        """
+        Update metadata in-place to advance one decode step.
+        """
+        # GPU in-place update is currently called separately through
+        # custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
+        # this logic to the backend.
+
+        # When using cudagraph, num_seqs is padded to the next captured
+        # batch size, but num_queries tracks the actual number of requests in
+        # the batch. For --enforce-eager mode, num_seqs == num_queries.
+        if num_seqs != num_queries:
+            assert num_seqs > num_queries
+            assert self.use_cuda_graph
+
+        assert self.num_prefills == 0
+        assert self.num_prefill_tokens == 0
+        assert self.num_decode_tokens == num_seqs
+        assert self.slot_mapping.shape == (num_seqs, )
+
+        assert self.seq_lens is not None
+        assert len(self.seq_lens) == num_seqs
+        assert self.seq_lens_tensor is not None
+        assert self.seq_lens_tensor.shape == (num_seqs, )
+        assert self.max_query_len == 1
+        assert self.max_prefill_seq_len == 0
+        assert self.max_decode_seq_len == max(self.seq_lens)
+
+        assert self.query_start_loc is not None
+        assert self.query_start_loc.shape == (num_queries + 1, )
+        assert self.seq_start_loc is not None
+        assert self.seq_start_loc.shape == (num_seqs + 1, )
+
+        assert self.context_lens_tensor is not None
+        assert self.context_lens_tensor.shape == (num_queries, )
+
+        assert self.block_tables is not None
+        assert self.block_tables.shape[0] == num_seqs
+
+        # Update query lengths. Note that we update only queries and not seqs,
+        # since tensors may be padded due to the captured cuda graph batch size.
+        for i in range(num_queries):
+            self.seq_lens[i] += 1
+        self.max_decode_seq_len = max(self.seq_lens)
+
 
 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
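The cudagraph comment in the new method is the key invariant: num_seqs may be padded up to the captured batch size while num_queries counts the real requests, so only the first num_queries entries advance. A minimal standalone sketch of that bookkeeping (not vLLM code; names and numbers are illustrative):

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class TinyDecodeMetadata:
        # One entry per (possibly padded) sequence slot.
        seq_lens: List[int]
        max_decode_seq_len: int = 0

        def advance_step(self, num_seqs: int, num_queries: int) -> None:
            # num_seqs may be padded to the captured graph batch size;
            # num_queries is the number of real requests in the batch.
            assert num_seqs >= num_queries
            assert len(self.seq_lens) == num_seqs
            # Advance only the real queries; padded slots stay untouched.
            for i in range(num_queries):
                self.seq_lens[i] += 1
            self.max_decode_seq_len = max(self.seq_lens)


    # Three live requests running inside a graph captured for batch size four.
    md = TinyDecodeMetadata(seq_lens=[17, 9, 33, 1])
    md.advance_step(num_seqs=4, num_queries=3)
    print(md.seq_lens, md.max_decode_seq_len)  # [18, 10, 34, 1] 34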
@@ -97,38 +97,6 @@ class TP1DraftModelRunner(ModelRunner):
         self.flashinfer_prefill_workspace_buffer = None
         self.flashinfer_prefill_wrapper = None
 
-    def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
-                                    num_queries):
-        assert isinstance(attn_metadata, FlashAttentionMetadata)
-
-        if num_seqs != num_queries:
-            assert num_seqs > num_queries
-            assert attn_metadata.use_cuda_graph
-
-        assert attn_metadata.num_prefills == 0
-        assert attn_metadata.num_prefill_tokens == 0
-        assert attn_metadata.num_decode_tokens == num_seqs
-        assert attn_metadata.slot_mapping.shape == (num_seqs, )
-
-        assert len(attn_metadata.seq_lens) == num_seqs
-        assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
-        assert attn_metadata.max_query_len == 1
-        assert attn_metadata.max_prefill_seq_len == 0
-        assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)
-
-        assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
-        assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )
-
-        assert attn_metadata.context_lens_tensor.shape == (num_queries, )
-
-        assert attn_metadata.block_tables.shape[0] == num_seqs
-
-        # Update query lengths. Note that we update only queries and not seqs,
-        # since tensors may be padded due to captured cuda graph batch size
-        for i in range(num_queries):
-            attn_metadata.seq_lens[i] += 1
-        attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)
-
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
 
@@ -166,7 +134,7 @@ class TP1DraftModelRunner(ModelRunner):
         # Update attn_metadata
         attn_metadata = model_input.attn_metadata
         assert isinstance(attn_metadata, FlashAttentionMetadata)
-        self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
+        attn_metadata.advance_step(num_seqs, num_queries)
 
         # Update GPU tensors
         ops.advance_step(num_seqs=num_seqs,
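Note that only the CPU-side metadata bookkeeping moves in this commit; the GPU tensors are still advanced separately through ops.advance_step(...) in the draft model runner, as the TODO inside the new FlashAttentionMetadata.advance_step points out.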