extend cuda graph size for H200 (#7894)

Co-authored-by: youkaichao <youkaichao@126.com>

commit c334b1898b
parent 6b3421567d
@@ -60,10 +60,14 @@ logger = init_logger(__name__)
 
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
-# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
 # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
 ]
 _NUM_WARMUP_ITERS = 2
 
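Not part of the diff, just a sketch of the effect of the list change, assuming
_BATCH_SIZE_ALIGNMENT stays at 8:

    old_sizes = [1, 2, 4] + [8 * i for i in range(1, 33)]    # max(old_sizes) == 256
    new_sizes = [1, 2, 4] + [8 * i for i in range(1, 1025)]  # max(new_sizes) == 8192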
@@ -660,7 +664,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     def _use_captured_graph(self, batch_size: int,
                             max_decode_seq_len: int) -> bool:
         return (self.decode_only and not self.runner.model_config.enforce_eager
-                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and batch_size <= self.runner.max_batchsize_to_capture
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
 
     def build(self) -> ModelInputForGPU:
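Not part of the diff: only sizes up to max_batchsize_to_capture are actually
captured (see the capture list further down), so the check must use the
per-runner cap rather than the last entry of the global list, which now
reaches 8192. An illustrative case with assumed values:

    # suppose max_num_seqs == 48, so graphs were captured only up to batch size 48
    # old check: batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] (now 8192) -> 64 would pass
    # new check: batch_size <= self.runner.max_batchsize_to_capture (48) -> 64 falls back to eager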
@@ -846,6 +850,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
+        self.max_batchsize_to_capture = _get_max_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -863,7 +869,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = np.zeros(
-            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
             dtype=np.int32)
         num_attn_heads = self.model_config.get_num_attention_heads(
             self.parallel_config)
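Not part of the diff, a rough sizing sketch with assumed values: with
max_batchsize_to_capture == 256, max_seq_len_to_capture == 8192 and
block_size == 16, the preallocated table stays small:

    blocks_per_batch = 8192 // 16              # 512
    table_bytes = 256 * blocks_per_batch * 4   # int32 -> ~0.5 MB
    # sized by max(_BATCH_SIZES_TO_CAPTURE) == 8192 instead, it would be ~16 MB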
@@ -1218,7 +1224,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         start_time = time.perf_counter()
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        max_batch_size = self.max_batchsize_to_capture
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
 
@@ -1246,8 +1252,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             None
         ] * self.parallel_config.pipeline_parallel_size
 
-        graph_batch_size = _get_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+        graph_batch_size = self.max_batchsize_to_capture
         batch_size_capture_list = [
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
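Not part of the diff, an illustrative capture list for an assumed
max_num_seqs of 48 (so max_batchsize_to_capture == 48):

    batch_size_capture_list = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= 48]
    # -> [1, 2, 4, 8, 16, 24, 32, 40, 48]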
@@ -1673,3 +1678,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
     else:
         return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                 _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+
+def _get_max_graph_batch_size(max_num_seqs: int) -> int:
+    """
+    max_num_seqs: Maximum number of sequences in a batch.
+    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+
+    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
+    which will deal with some edge cases like 1, 2, 4.
+
+    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
+    if not, it means the padded size is larger than the largest size in
+    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
+    """
+    padded_size = _get_graph_batch_size(max_num_seqs)
+    if padded_size in _BATCH_SIZES_TO_CAPTURE:
+        return padded_size
+    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+    return _BATCH_SIZES_TO_CAPTURE[-1]
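Not part of the diff, a couple of illustrative calls with assumed inputs,
using the default alignment of 8 and the extended capture list above:

    _get_max_graph_batch_size(250)    # padded up to 256, which is in the list -> 256
    _get_max_graph_batch_size(10000)  # padded size exceeds the largest entry -> 8192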