[Model][Jamba] Mamba cache single buffer (#6739)

Co-authored-by: Mor Zusman <morz@ai21.com>
Mor Zusman authored on 2024-08-09 17:07:06 +03:00; committed by GitHub
parent b4e9528f95
commit 07ab160741
2 changed files with 148 additions and 124 deletions


@@ -609,12 +609,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        # Current step used indices
-        self.current_indices: List[int] = []
         # Used to track and store by the Mamba cache between steps.
         self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Used as an input_buffer for the CUDA graph runs.
-        self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
         # Maps between the request id and a dict that maps between the seq_id
         # and its index inside the self.mamba_cache
         self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
@@ -644,95 +640,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         batch_size = input_ids.shape[0]
         if attn_metadata.prefill_metadata:
             batch_size = len(request_ids_to_seq_ids)
-            (
-                current_seqlen_agnostic_cache,
-                indices,
-            ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                      batch_size,
-                                                      finished_requests_ids)
+            mamba_cache = self._prepare_current_run_mamba_cache(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
         else:
             # CUDA graph capturing runs
-            current_seqlen_agnostic_cache, indices = (
-                kwargs["seqlen_agnostic_capture_inputs"],
-                [],
-            )
-        self.current_indices = indices
+            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]

         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata,
-                                   current_seqlen_agnostic_cache[0],
-                                   current_seqlen_agnostic_cache[1])
-
-        if "seqlen_agnostic_capture_inputs" not in kwargs:
-            self._copy_mamba_cache_by_indices(self.current_indices,
-                                              current_seqlen_agnostic_cache)
-
+                                   attn_metadata, mamba_cache[0],
+                                   mamba_cache[1])
         return hidden_states

-    def _copy_mamba_cache_by_indices(
-            self, indices: List[int],
-            current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
-        for i, offset in enumerate(indices):
-            self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
-
-    def _copy_mamba_cache(self, index_to: int, index_from: int,
-                          from_buffer: Tuple[torch.Tensor, torch.Tensor]):
+    def _swap_mamba_cache(self, from_index: int, to_index: int):
         assert len(self.mamba_cache) > 0
-        for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
-            cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
+        for cache_t in self.mamba_cache:
+            cache_t[:, [to_index,from_index]] = \
+                cache_t[:, [from_index,to_index]]
+
+    def _copy_mamba_cache(self, from_index: int, to_index: int):
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
                                        non_blocking=True)

-    def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
-                                      seqs_id: List[int]) -> List[int]:
-        indices_for_current_run = []
-        for seq_id in seqs_id:
-            if cur_rid not in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping[cur_rid] = {}
-                first_free_index = self._first_free_index_in_mamba_cache()
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            ## case of decoding n>1, copy prefill cache to decoding indices
-            elif seq_id not in (seq_ids2indices :=
-                                self.mamba_cache_indices_mapping[cur_rid]):
-                first_free_index = self._first_free_index_in_mamba_cache()
-                index_exist = list(seq_ids2indices.values())[0]
-                self._copy_mamba_cache(index_from=index_exist,
-                                       index_to=first_free_index,
-                                       from_buffer=self.mamba_cache)
-                self.mamba_cache_indices_mapping[cur_rid][
-                    seq_id] = first_free_index
-                index_for_current_run = first_free_index
-            else:
-                index_for_current_run = self.mamba_cache_indices_mapping[
-                    cur_rid][seq_id]
-            indices_for_current_run.append(index_for_current_run)
-        return indices_for_current_run
+    def _move_out_if_already_occupied(self, index: int,
+                                      all_occupied_indices: List[int]):
+        if index in all_occupied_indices:
+            first_free_index = self._first_free_index_in_mamba_cache()
+            # In case occupied, move the occupied to a new empty block
+            self._move_cache_index_and_mappings(from_index=index,
+                                                to_index=first_free_index)
+
+    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
+                                                       seq_id: int,
+                                                       destination_index: int):
+        """
+        Assign the (req_id, seq_id) pair to a `destination_index` index; if it
+        is already occupied, move the occupying index to a free index.
+        """
+        all_occupied_indices = self._get_all_occupied_indices()
+        if cur_rid not in self.mamba_cache_indices_mapping:
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            self.mamba_cache_indices_mapping[cur_rid] = {
+                seq_id: destination_index
+            }
+        elif seq_id not in (seq_ids2indices :=
+                            self.mamba_cache_indices_mapping[cur_rid]):
+            # parallel sampling, where n > 1: prefill is assumed to have
+            # already happened, so we only need to copy the existing cache
+            # into the sibling seq_ids' caches
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            index_exists = list(seq_ids2indices.values())[0]
+            # case of decoding n>1, copy prefill cache to decoding indices
+            self._copy_mamba_cache(from_index=index_exists,
+                                   to_index=destination_index)
+            self.mamba_cache_indices_mapping[cur_rid][
+                seq_id] = destination_index
+        else:
+            # already exists
+            cache_index_already_exists = self.mamba_cache_indices_mapping[
+                cur_rid][seq_id]
+            if cache_index_already_exists != destination_index:
+                # In case the seq id already exists but not in
+                # the right destination, swap it with what's occupying it
+                self._swap_pair_indices_and_mappings(
+                    from_index=cache_index_already_exists,
+                    to_index=destination_index)

     def _prepare_current_run_mamba_cache(
-        self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
-        finished_requests_ids: List[str]
-    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
-        indices_for_current_run = []
-        for request_id, seqs_id in request_ids_to_seq_ids.items():
+            self, request_ids_to_seq_ids: Dict[str, list[int]],
+            batch_size: int, finished_requests_ids: List[str]):
+        running_indices = []
+        request_ids_to_seq_ids_flatten = [
+            (req_id, seq_id)
+            for req_id, seq_ids in request_ids_to_seq_ids.items()
+            for seq_id in seq_ids
+        ]
+        for dest_index, (request_id,
+                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
             if request_id in finished_requests_ids:
-                # Do not allocate cache for requests that run
+                # Do not allocate cache index for requests that run
                 # and finish right after
                 continue
-            indices_for_current_run += self._assign_seq_id_to_mamba_cache(
-                request_id, seqs_id)
-        ## Pad the batch in case of running batch that was not captured via CG
-        padded_indices = indices_for_current_run.copy()
-        pad_index = self._first_free_index_in_mamba_cache()
-
-        for _ in range(batch_size - len(indices_for_current_run)):
-            padded_indices.append(pad_index)
-
-        conv_state = self.mamba_cache[0][:, padded_indices]
-        temporal_state = self.mamba_cache[1][:, padded_indices]
-
-        return (conv_state, temporal_state), indices_for_current_run
+            self._assign_seq_id_to_mamba_cache_in_specific_dest(
+                request_id, seq_id, dest_index)
+            running_indices.append(dest_index)
+
+        self._clean_up_first_bs_blocks(batch_size, running_indices)
+        conv_state = self.mamba_cache[0][:, :batch_size]
+        temporal_state = self.mamba_cache[1][:, :batch_size]
+
+        return (conv_state, temporal_state)
+
+    def _get_all_occupied_indices(self):
+        return [
+            cache_idx
+            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
+            for cache_idx in seq_ids2indices.values()
+        ]
+
+    def _clean_up_first_bs_blocks(self, batch_size: int,
+                                  indices_for_current_run: List[int]):
+        # move out all of the occupied but currently not running blocks
+        # outside of the first n blocks
+        destination_indices = set([range(batch_size)])
+        max_possible_batch_size = self.mamba_cache[0].shape[1]
+        for destination_index in destination_indices:
+            if destination_index in self._get_all_occupied_indices() and \
+                    destination_index not in indices_for_current_run:
+                # move not running indices outside of the batch
+                all_other_indices = list(
+                    range(batch_size, max_possible_batch_size))
+                first_avail_index = self._first_free_index_in_mamba_cache(
+                    all_other_indices)
+                self._swap_indices(from_index=destination_index,
+                                   to_index=first_avail_index)
+
+    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
+        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
+        self._update_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
+        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
+        self._swap_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                elif to_index == index:
+                    seq_ids2index.update({seq_id: from_index})
+
+    def _update_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                    return

     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
         """
@@ -747,28 +796,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         self._release_mamba_cache(finished_requests_ids)
         request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
         cg_batch_size = input_buffers['input_ids'].shape[0]
-        (
-            current_mamba_cache,
-            indices,
-        ) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                                  cg_batch_size,
-                                                  finished_requests_ids)
-        self.current_indices = indices
-
-        for input_buffer, current_cache_buffer in zip(
-                input_buffers["seqlen_agnostic_capture_inputs"],
-                current_mamba_cache):
-            input_buffer.copy_(current_cache_buffer, non_blocking=True)
-
-    def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache from the CUDA graph input_buffers
-        back to the JambaForCausalLM.mamba_cache after CUDA
-        graph replay run is done.
-        """
-        self._copy_mamba_cache_by_indices(
-            self.current_indices,
-            input_buffers["seqlen_agnostic_capture_inputs"])
+        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
+                                              cg_batch_size,
+                                              finished_requests_ids)

     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
         """
@@ -776,26 +806,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
         The buffer is used to maintain the Mamba Cache during the CUDA graph
         replay runs.
         """
-        return tuple(buffer[:, :batch_size]
-                     for buffer in self.mamba_gc_cache_buffer)
+        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)

     def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
         for req_id in finished_seq_groups_req_ids:
             if req_id in self.mamba_cache_indices_mapping:
                 self.mamba_cache_indices_mapping.pop(req_id)

-    def _first_free_index_in_mamba_cache(self) -> int:
-        if self.mamba_cache:
-            max_possible_batch_size = self.mamba_cache[0].shape[1]
-            occupied = [
-                id for seq_ids in self.mamba_cache_indices_mapping.values()
-                for id in seq_ids.values()
-            ]
-            first_free_index = [
-                i not in occupied for i in range(max_possible_batch_size)
-            ].index(True)
-            return first_free_index
-        return 0
+    def _first_free_index_in_mamba_cache(
+            self, indices_range: Optional[List[int]] = None) -> int:
+        assert self.mamba_cache is not None
+        if indices_range is None:
+            max_possible_batch_size = self.mamba_cache[0].shape[1]
+            indices_range = list(range(max_possible_batch_size))
+        all_occupied_indices = self._get_all_occupied_indices()
+        for i in indices_range:
+            if i not in all_occupied_indices:
+                return i
+        raise Exception("Couldn't find a free spot in the mamba cache! "
+                        "This should never happen")

     def _get_mamba_cache_shape(
             self
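A small standalone sketch of the restricted free-slot search that the new `_first_free_index_in_mamba_cache` signature enables: callers such as `_clean_up_first_bs_blocks` can ask for a free row only among the slots beyond the current batch. The `first_free_index` helper and its arguments are invented for illustration.

```python
from typing import List, Optional

def first_free_index(occupied: List[int],
                     total_slots: int,
                     indices_range: Optional[List[int]] = None) -> int:
    """Return the first slot in `indices_range` (default: all slots) that no
    (request, seq) pair currently owns; raise if everything is taken."""
    if indices_range is None:
        indices_range = list(range(total_slots))
    occupied_set = set(occupied)
    for i in indices_range:
        if i not in occupied_set:
            return i
    raise RuntimeError("no free slot in the requested range")

occupied = [0, 1, 4]          # rows currently mapped to some sequence
batch_size, total = 3, 8

# Unrestricted search: first free row anywhere.
assert first_free_index(occupied, total) == 2

# Restricted search, as used when evicting a non-running occupant out of
# the first `batch_size` rows: only consider rows >= batch_size.
outside_batch = list(range(batch_size, total))
assert first_free_index(occupied, total, outside_batch) == 3
```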
@@ -819,20 +848,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             [layer_type == "mamba" for layer_type in layers_type])

         max_batch_size = (_get_graph_batch_size(
             self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE)) + 10
+                          max(_BATCH_SIZES_TO_CAPTURE) + 2)
         conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
         assert conv_state_shape is not None and temporal_state_shape is not None
-        for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
-            buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
-                                  conv_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"),
-                      torch.empty(size=(mamba_layers, max_batch_size) +
-                                  temporal_state_shape,
-                                  dtype=dtype,
-                                  device="cuda"))
-            setattr(self, buffername, buffer)
+        self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
+                                        conv_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"),
+                            torch.empty(size=(mamba_layers, max_batch_size) +
+                                        temporal_state_shape,
+                                        dtype=dtype,
+                                        device="cuda"))

     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:


@@ -1711,9 +1711,6 @@ class CUDAGraphRunner:
                                  non_blocking=True)
         # Run the graph.
         self.graph.replay()
-        if "seqlen_agnostic_capture_inputs" in self.input_buffers:
-            self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
-                                                      **kwargs)
         # Return the output tensor.
         if get_pp_group().is_last_rank:
             return self.output_buffers["hidden_states"]
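A hedged, generic illustration of why dropping the post-replay copy is safe once the graph is captured directly on (a view of) the persistent buffer: CUDA graphs replay against fixed memory addresses, so mutating that memory in place before replay is the supported way to feed new data, and whatever the graph writes stays in that same memory afterwards. This is a standalone PyTorch sketch (requires a CUDA device), not vLLM code.

```python
import torch

assert torch.cuda.is_available(), "CUDA graphs need a GPU"

# Persistent state buffer; the graph will read and update it in place.
state = torch.zeros(4, device="cuda")

g = torch.cuda.CUDAGraph()

# Warm up on a side stream before capture (recommended by PyTorch docs).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    state.add_(1.0)
torch.cuda.current_stream().wait_stream(s)

state.zero_()
with torch.cuda.graph(g):
    state.add_(1.0)  # captured: "advance the state in place"

# Replay twice; every replay mutates the same `state` storage directly,
# so there is nothing to copy back afterwards.
g.replay()
g.replay()
torch.cuda.synchronize()
print(state)  # tensor([2., 2., 2., 2.], device='cuda:0')
```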