[Core] CUDA Graphs for Multi-Step + Chunked-Prefill (#8645)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 7f60520deb
commit afb050b29d
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
     long const* sampled_token_ids_ptr, long* input_positions_ptr,
     int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
     int64_t const block_tables_stride) {
+  int const n_pad = num_seqs - num_queries;
+  if (n_pad && blockIdx.x == 0) {
+    // Handle cuda graph padding
+    int const offset = num_queries;
+    for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
+      input_tokens_ptr[offset + i] = 0;
+      input_positions_ptr[offset + i] = 0;
+      slot_mapping_ptr[offset + i] = -1;
+    }
+  }
+
   int num_query_blocks = div_ceil(num_queries, num_threads);

   if (blockIdx.x >= num_query_blocks) {
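The new block above zero-fills the CUDA-graph padding region: entries beyond num_queries carry no real work, so their tokens and positions are cleared and their slot mapping is set to -1 so nothing is written to the KV cache. Below is a minimal host-side Python sketch of the same bookkeeping, illustrative only and not part of this commit; pad_for_cuda_graph is a made-up helper name.

import numpy as np

PAD_SLOT_ID = -1  # sentinel meaning "do not write this slot to the KV cache"

def pad_for_cuda_graph(input_tokens: np.ndarray,
                       input_positions: np.ndarray,
                       slot_mapping: np.ndarray,
                       num_queries: int) -> None:
    """Neutralize the trailing CUDA-graph padding entries in place."""
    num_seqs = input_tokens.shape[0]   # padded batch size captured by the graph
    n_pad = num_seqs - num_queries     # entries past the live queries
    if n_pad:
        input_tokens[num_queries:] = 0
        input_positions[num_queries:] = 0
        slot_mapping[num_queries:] = PAD_SLOT_ID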
@@ -500,6 +500,30 @@ class FlashAttentionMetadataBuilder(
                 seq_len, context_len, start_idx,
                 self.block_size, inter_data.block_tables)

+    def _get_graph_runner_block_tables(
+            self, num_seqs: int,
+            block_tables: List[List[int]]) -> torch.Tensor:
+        # The shape of graph_block_tables is
+        # [max batch size, max context len // block size].
+        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+        assert max_batch_size >= num_seqs
+
+        graph_block_tables = self.runner.graph_block_tables[:num_seqs]
+        for i, block_table in enumerate(block_tables):
+            if block_table:
+                num_blocks = len(block_table)
+                if num_blocks <= max_blocks:
+                    graph_block_tables[i, :num_blocks] = block_table
+                else:
+                    # It may be possible to have more blocks allocated due
+                    # to lookahead slots of multi-step, however, they are
+                    # not used anyway, so can be safely ignored.
+                    graph_block_tables[
+                        i, :max_blocks] = block_table[:max_blocks]
+
+        return torch.from_numpy(graph_block_tables).to(
+            device=self.runner.device, non_blocking=True)
+
     def build(self, seq_lens: List[int], query_lens: List[int],
               cuda_graph_pad_size: int, batch_size: int):
         """Build attention metadata with on-device tensors.
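The new _get_graph_runner_block_tables helper copies the per-sequence block tables into the runner's preallocated graph_block_tables array (shape [max batch size, max context len // block size]), truncating any table that exceeds the captured width. A small self-contained sketch with toy sizes, not vLLM's defaults:

import numpy as np
import torch

# Toy dimensions; vLLM preallocates this array once per runner.
max_batch_size, max_blocks = 8, 4
graph_block_tables = np.zeros((max_batch_size, max_blocks), dtype=np.int32)

# Per-sequence block tables; the third one is longer than max_blocks
# (e.g. due to multi-step lookahead slots) and gets truncated.
block_tables = [[10, 11], [], [20, 21, 22, 23, 24]]
num_seqs = len(block_tables)

view = graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
    if block_table:
        num_blocks = min(len(block_table), max_blocks)
        view[i, :num_blocks] = block_table[:num_blocks]

device_tables = torch.from_numpy(view)  # .to("cuda", non_blocking=True) on a GPU
print(device_tables)
# tensor([[10, 11,  0,  0],
#         [ 0,  0,  0,  0],
#         [20, 21, 22, 23]], dtype=torch.int32)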
@@ -533,29 +557,13 @@ class FlashAttentionMetadataBuilder(
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens

+        num_seqs = len(seq_lens)
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
             self.block_tables.extend([] * cuda_graph_pad_size)
-            num_decode_tokens = batch_size
-
-            # The shape of graph_block_tables is
-            # [max batch size, max context len // block size].
-            input_block_tables = self.runner.graph_block_tables[:batch_size]
-            max_blocks = input_block_tables.shape[1]
-            for i, block_table in enumerate(self.block_tables):
-                if block_table:
-                    num_blocks = len(block_table)
-                    if num_blocks <= max_blocks:
-                        input_block_tables[i, :num_blocks] = block_table
-                    else:
-                        # It may be possible to have more blocks allocated due
-                        # to lookahead slots of multi-step, however, they are
-                        # not used anyway, so can be safely ignored.
-                        input_block_tables[
-                            i, :max_blocks] = block_table[:max_blocks]
-
-            block_tables = torch.from_numpy(input_block_tables).to(
-                device=device, non_blocking=True)
+            num_decode_tokens = batch_size - self.num_prefill_tokens
+            block_tables = self._get_graph_runner_block_tables(
+                num_seqs, self.block_tables)
         else:
             block_tables = make_tensor_with_pad(
                 self.block_tables,
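Two things change in build when a captured graph is used: block tables now come from the helper above, and num_decode_tokens subtracts the prefill tokens that chunked prefill may have scheduled in the same step instead of assuming the whole padded batch is decodes. A hedged arithmetic sketch with made-up numbers:

# Illustrative numbers only (not taken from the commit).
num_prefill_tokens = 12                 # chunked-prefill tokens in this step
num_decode_seqs = 5                     # decode sequences (one token each)
graph_batch_size = 8                    # captured graph size chosen for the decodes
cuda_graph_pad_size = graph_batch_size - num_decode_seqs                 # 3
batch_size = num_prefill_tokens + num_decode_seqs + cuda_graph_pad_size  # 20

# Old: num_decode_tokens = batch_size                      -> 20 (counts prefills)
# New: num_decode_tokens = batch_size - num_prefill_tokens -> 8
num_decode_tokens = batch_size - num_prefill_tokens
assert num_decode_tokens == graph_batch_size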
@@ -712,14 +712,62 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

     def _use_captured_graph(self,
                             batch_size: int,
+                            decode_only: bool,
                             max_decode_seq_len: int,
                             max_encoder_seq_len: int = 0) -> bool:
-        return (self.decode_only and not self.runner.model_config.enforce_eager
+        return (decode_only and not self.runner.model_config.enforce_eager
                 and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)

+    def _get_cuda_graph_pad_size(self,
+                                 num_seqs: int,
+                                 max_decode_seq_len: int,
+                                 max_encoder_seq_len: int = 0) -> int:
+        """
+        Determine the number of padding sequences required for running in
+        CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
+
+        In the multi-step + chunked-prefill case, only the first step
+        has prefills (if any). The rest of the steps are guaranteed to be all
+        decodes. In this case, we set up the padding as if all the sequences
+        are decodes so we may run all steps except the first step in CUDA graph
+        mode. The padding is accounted for in the multi-step `advance_step`
+        family of functions.
+
+        Args:
+            num_seqs (int): Number of sequences scheduled to run.
+            max_decode_seq_len (int): Greatest of all the decode sequence
+                lengths. Used only in checking the viability of using
+                CUDA graphs.
+            max_encoder_seq_len (int, optional): Greatest of all the encoder
+                sequence lengths. Defaults to 0. Used only in checking the
+                viability of using CUDA graphs.
+        Returns:
+            int: Returns the determined number of padding sequences. If
+                CUDA graphs are not viable, returns -1.
+        """
+        is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
+            self.runner.scheduler_config.chunked_prefill_enabled
+        decode_only = self.decode_only or is_mscp
+        if not decode_only:
+            # Early exit so we can treat num_seqs as the batch_size below.
+            return -1
+
+        # Outside this function, batch_size refers to the number of input
+        # tokens being scheduled. Treating num_seqs as batch_size here is
+        # valid because this is a decode-only case.
+        batch_size = num_seqs
+        if not self._use_captured_graph(batch_size, decode_only,
+                                        max_decode_seq_len,
+                                        max_encoder_seq_len):
+            return -1
+
+        graph_batch_size = _get_graph_batch_size(batch_size)
+        assert graph_batch_size >= batch_size
+        return graph_batch_size - batch_size
+
     def build(self) -> ModelInputForGPU:
         """Finalize the builder intermediate data and
         create on-device tensors.
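_get_cuda_graph_pad_size reduces to: treat the batch as decode-only (which multi-step + chunked-prefill guarantees after the first step), check graph viability, then pad up to the nearest captured batch size. The sketch below uses a simplified stand-in for _get_graph_batch_size (vLLM captures sizes 1, 2, 4 and multiples of 8; the helper here only mimics that pattern and is not the library function itself):

def graph_batch_size_for(batch_size: int) -> int:
    """Round a decode batch up to the next captured graph size (stand-in)."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8  # next multiple of 8

def cuda_graph_pad_size_for(num_decode_seqs: int) -> int:
    """Number of padding sequences needed to hit a captured batch size."""
    return graph_batch_size_for(num_decode_seqs) - num_decode_seqs

assert cuda_graph_pad_size_for(3) == 1    # runs the batch-size-4 graph
assert cuda_graph_pad_size_for(13) == 3   # runs the batch-size-16 graph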
@@ -778,21 +826,17 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             for data in self.inter_data_list
         }

-        batch_size = len(input_tokens)
-        use_captured_graph = self._use_captured_graph(
-            batch_size,
-            max_decode_seq_len,
+        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
+            num_seqs=len(seq_lens),
+            max_decode_seq_len=max_decode_seq_len,
             max_encoder_seq_len=max_encoder_seq_len)

-        # If cuda graph can be used, pad tensors accordingly.
-        # See `capture_model` API for more details.
-        # vLLM uses cuda graph only for decoding requests.
-        cuda_graph_pad_size = -1
-        if use_captured_graph:
-            graph_batch_size = _get_graph_batch_size(batch_size)
-            assert graph_batch_size >= batch_size
-            cuda_graph_pad_size = graph_batch_size - batch_size
-            batch_size = graph_batch_size
+        batch_size = len(input_tokens)
+        if cuda_graph_pad_size != -1:
+            # If cuda graph can be used, pad tensors accordingly.
+            # See `capture_model` API for more details.
+            # vLLM uses cuda graph only for decoding requests.
+            batch_size += cuda_graph_pad_size

         # Tokens and positions.
         if cuda_graph_pad_size:
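The model-runner build now asks for the pad size up front and uses -1 as the "run eagerly" signal, rather than recomputing the graph batch size inline. A condensed sketch of the resulting control flow (the function and parameter names below are placeholders, not vLLM APIs):

from typing import Callable

def plan_batch_size(num_seqs: int, num_input_tokens: int,
                    pad_size_fn: Callable[[int], int]) -> int:
    """Return the effective batch size after optional CUDA-graph padding."""
    cuda_graph_pad_size = pad_size_fn(num_seqs)  # -1 => CUDA graphs not viable
    batch_size = num_input_tokens
    if cuda_graph_pad_size != -1:
        # Pad up to the captured graph's batch size; padded entries are
        # neutralized via PAD_SLOT_ID slot mappings and the advance_step kernel.
        batch_size += cuda_graph_pad_size
    return batch_size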