[Model][Jamba] Mamba cache single buffer (#6739)
Co-authored-by: Mor Zusman <morz@ai21.com>
This commit is contained in:
parent
b4e9528f95
commit
07ab160741
@ -609,12 +609,8 @@ class JambaForCausalLM(nn.Module, HasInnerState):
|
||||
# compatibility
|
||||
if not lora_config else lora_config.lora_vocab_padding_size,
|
||||
)
|
||||
# Current step used indices
|
||||
self.current_indices: List[int] = []
|
||||
# Used to track and store by the Mamba cache between steps.
|
||||
self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
|
||||
# Used as an input_buffer for the CUDA graph runs.
|
||||
self.mamba_gc_cache_buffer: Tuple[torch.Tensor, torch.Tensor] = tuple()
|
||||
# Maps between the request id and a dict that maps between the seq_id
|
||||
# and its index inside the self.mamba_cache
|
||||
self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
|
||||
@ -644,95 +640,148 @@ class JambaForCausalLM(nn.Module, HasInnerState):
|
||||
batch_size = input_ids.shape[0]
|
||||
if attn_metadata.prefill_metadata:
|
||||
batch_size = len(request_ids_to_seq_ids)
|
||||
(
|
||||
current_seqlen_agnostic_cache,
|
||||
indices,
|
||||
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
batch_size,
|
||||
finished_requests_ids)
|
||||
mamba_cache = self._prepare_current_run_mamba_cache(
|
||||
request_ids_to_seq_ids, batch_size, finished_requests_ids)
|
||||
else:
|
||||
# CUDA graph capturing runs
|
||||
current_seqlen_agnostic_cache, indices = (
|
||||
kwargs["seqlen_agnostic_capture_inputs"],
|
||||
[],
|
||||
)
|
||||
self.current_indices = indices
|
||||
mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
|
||||
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata,
|
||||
current_seqlen_agnostic_cache[0],
|
||||
current_seqlen_agnostic_cache[1])
|
||||
|
||||
if "seqlen_agnostic_capture_inputs" not in kwargs:
|
||||
self._copy_mamba_cache_by_indices(self.current_indices,
|
||||
current_seqlen_agnostic_cache)
|
||||
|
||||
attn_metadata, mamba_cache[0],
|
||||
mamba_cache[1])
|
||||
return hidden_states
|
||||
|
||||
def _copy_mamba_cache_by_indices(
|
||||
self, indices: List[int],
|
||||
current_seqlen_agnostic_cache: Tuple[torch.Tensor, torch.Tensor]):
|
||||
for i, offset in enumerate(indices):
|
||||
self._copy_mamba_cache(offset, i, current_seqlen_agnostic_cache)
|
||||
|
||||
def _copy_mamba_cache(self, index_to: int, index_from: int,
|
||||
from_buffer: Tuple[torch.Tensor, torch.Tensor]):
|
||||
def _swap_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for (cache_t, from_buffer_t) in zip(self.mamba_cache, from_buffer):
|
||||
cache_t[:, index_to].copy_(from_buffer_t[:, index_from],
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, [to_index,from_index]] = \
|
||||
cache_t[:, [from_index,to_index]]
|
||||
|
||||
def _copy_mamba_cache(self, from_index: int, to_index: int):
|
||||
assert len(self.mamba_cache) > 0
|
||||
for cache_t in self.mamba_cache:
|
||||
cache_t[:, to_index].copy_(cache_t[:, from_index],
|
||||
non_blocking=True)
|
||||
|
||||
def _assign_seq_id_to_mamba_cache(self, cur_rid: str,
|
||||
seqs_id: List[int]) -> List[int]:
|
||||
indices_for_current_run = []
|
||||
for seq_id in seqs_id:
|
||||
if cur_rid not in self.mamba_cache_indices_mapping:
|
||||
self.mamba_cache_indices_mapping[cur_rid] = {}
|
||||
first_free_index = self._first_free_index_in_mamba_cache()
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = first_free_index
|
||||
index_for_current_run = first_free_index
|
||||
## case of decoding n>1, copy prefill cache to decoding indices
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.mamba_cache_indices_mapping[cur_rid]):
|
||||
first_free_index = self._first_free_index_in_mamba_cache()
|
||||
index_exist = list(seq_ids2indices.values())[0]
|
||||
self._copy_mamba_cache(index_from=index_exist,
|
||||
index_to=first_free_index,
|
||||
from_buffer=self.mamba_cache)
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = first_free_index
|
||||
index_for_current_run = first_free_index
|
||||
else:
|
||||
index_for_current_run = self.mamba_cache_indices_mapping[
|
||||
cur_rid][seq_id]
|
||||
def _move_out_if_already_occupied(self, index: int,
|
||||
all_occupied_indices: List[int]):
|
||||
if index in all_occupied_indices:
|
||||
first_free_index = self._first_free_index_in_mamba_cache()
|
||||
# In case occupied, move the occupied to a new empty block
|
||||
self._move_cache_index_and_mappings(from_index=index,
|
||||
to_index=first_free_index)
|
||||
|
||||
indices_for_current_run.append(index_for_current_run)
|
||||
return indices_for_current_run
|
||||
def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
|
||||
seq_id: int,
|
||||
destination_index: int):
|
||||
"""
|
||||
Assign (req_id,seq_id) pair to a `destination_index` index, if
|
||||
already occupied, move the occupying index to a free index.
|
||||
"""
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
if cur_rid not in self.mamba_cache_indices_mapping:
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
self.mamba_cache_indices_mapping[cur_rid] = {
|
||||
seq_id: destination_index
|
||||
}
|
||||
elif seq_id not in (seq_ids2indices :=
|
||||
self.mamba_cache_indices_mapping[cur_rid]):
|
||||
# parallel sampling , where n > 1, assume prefill have
|
||||
# already happened now we only need to copy the already
|
||||
# existing cache into the siblings seq_ids caches
|
||||
self._move_out_if_already_occupied(
|
||||
index=destination_index,
|
||||
all_occupied_indices=all_occupied_indices)
|
||||
index_exists = list(seq_ids2indices.values())[0]
|
||||
# case of decoding n>1, copy prefill cache to decoding indices
|
||||
self._copy_mamba_cache(from_index=index_exists,
|
||||
to_index=destination_index)
|
||||
self.mamba_cache_indices_mapping[cur_rid][
|
||||
seq_id] = destination_index
|
||||
else:
|
||||
# already exists
|
||||
cache_index_already_exists = self.mamba_cache_indices_mapping[
|
||||
cur_rid][seq_id]
|
||||
if cache_index_already_exists != destination_index:
|
||||
# In case the seq id already exists but not in
|
||||
# the right destination, swap it with what's occupying it
|
||||
self._swap_pair_indices_and_mappings(
|
||||
from_index=cache_index_already_exists,
|
||||
to_index=destination_index)
|
||||
|
||||
def _prepare_current_run_mamba_cache(
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]], batch_size: int,
|
||||
finished_requests_ids: List[str]
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], List[int]]:
|
||||
indices_for_current_run = []
|
||||
for request_id, seqs_id in request_ids_to_seq_ids.items():
|
||||
self, request_ids_to_seq_ids: Dict[str, list[int]],
|
||||
batch_size: int, finished_requests_ids: List[str]):
|
||||
running_indices = []
|
||||
request_ids_to_seq_ids_flatten = [
|
||||
(req_id, seq_id)
|
||||
for req_id, seq_ids in request_ids_to_seq_ids.items()
|
||||
for seq_id in seq_ids
|
||||
]
|
||||
for dest_index, (request_id,
|
||||
seq_id) in enumerate(request_ids_to_seq_ids_flatten):
|
||||
if request_id in finished_requests_ids:
|
||||
# Do not allocate cache for requests that run
|
||||
# Do not allocate cache index for requests that run
|
||||
# and finish right after
|
||||
continue
|
||||
indices_for_current_run += self._assign_seq_id_to_mamba_cache(
|
||||
request_id, seqs_id)
|
||||
## Pad the batch in case of running batch that was not captured via CG
|
||||
padded_indices = indices_for_current_run.copy()
|
||||
pad_index = self._first_free_index_in_mamba_cache()
|
||||
self._assign_seq_id_to_mamba_cache_in_specific_dest(
|
||||
request_id, seq_id, dest_index)
|
||||
running_indices.append(dest_index)
|
||||
|
||||
for _ in range(batch_size - len(indices_for_current_run)):
|
||||
padded_indices.append(pad_index)
|
||||
self._clean_up_first_bs_blocks(batch_size, running_indices)
|
||||
conv_state = self.mamba_cache[0][:, :batch_size]
|
||||
temporal_state = self.mamba_cache[1][:, :batch_size]
|
||||
|
||||
conv_state = self.mamba_cache[0][:, padded_indices]
|
||||
temporal_state = self.mamba_cache[1][:, padded_indices]
|
||||
return (conv_state, temporal_state)
|
||||
|
||||
return (conv_state, temporal_state), indices_for_current_run
|
||||
def _get_all_occupied_indices(self):
|
||||
return [
|
||||
cache_idx
|
||||
for seq_ids2indices in self.mamba_cache_indices_mapping.values()
|
||||
for cache_idx in seq_ids2indices.values()
|
||||
]
|
||||
|
||||
def _clean_up_first_bs_blocks(self, batch_size: int,
|
||||
indices_for_current_run: List[int]):
|
||||
# move out all of the occupied but currently not running blocks
|
||||
# outside of the first n blocks
|
||||
destination_indices = set([range(batch_size)])
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
for destination_index in destination_indices:
|
||||
if destination_index in self._get_all_occupied_indices() and \
|
||||
destination_index not in indices_for_current_run:
|
||||
# move not running indices outside of the batch
|
||||
all_other_indices = list(
|
||||
range(batch_size, max_possible_batch_size))
|
||||
first_avail_index = self._first_free_index_in_mamba_cache(
|
||||
all_other_indices)
|
||||
self._swap_indices(from_index=destination_index,
|
||||
to_index=first_avail_index)
|
||||
|
||||
def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
|
||||
self._copy_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._update_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
|
||||
self._swap_mamba_cache(from_index=from_index, to_index=to_index)
|
||||
self._swap_mapping_index(from_index=from_index, to_index=to_index)
|
||||
|
||||
def _swap_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
elif to_index == index:
|
||||
seq_ids2index.update({seq_id: from_index})
|
||||
|
||||
def _update_mapping_index(self, from_index: int, to_index: int):
|
||||
for seq_ids2index in self.mamba_cache_indices_mapping.values():
|
||||
for seq_id, index in seq_ids2index.items():
|
||||
if from_index == index:
|
||||
seq_ids2index.update({seq_id: to_index})
|
||||
return
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
@ -747,28 +796,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
|
||||
self._release_mamba_cache(finished_requests_ids)
|
||||
request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
|
||||
cg_batch_size = input_buffers['input_ids'].shape[0]
|
||||
(
|
||||
current_mamba_cache,
|
||||
indices,
|
||||
) = self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
cg_batch_size,
|
||||
finished_requests_ids)
|
||||
self.current_indices = indices
|
||||
|
||||
for input_buffer, current_cache_buffer in zip(
|
||||
input_buffers["seqlen_agnostic_capture_inputs"],
|
||||
current_mamba_cache):
|
||||
input_buffer.copy_(current_cache_buffer, non_blocking=True)
|
||||
|
||||
def copy_outputs_after_cuda_graphs(self, input_buffers, **kwargs):
|
||||
"""
|
||||
Copy the relevant Mamba cache from the CUDA graph input_buffers
|
||||
back to the JambaForCausalLM.mamba_cache after CUDA
|
||||
graph replay run is done.
|
||||
"""
|
||||
self._copy_mamba_cache_by_indices(
|
||||
self.current_indices,
|
||||
input_buffers["seqlen_agnostic_capture_inputs"])
|
||||
self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
|
||||
cg_batch_size,
|
||||
finished_requests_ids)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
"""
|
||||
@ -776,26 +806,25 @@ class JambaForCausalLM(nn.Module, HasInnerState):
|
||||
The buffer is used to maintain the Mamba Cache during the CUDA graph
|
||||
replay runs.
|
||||
"""
|
||||
return tuple(buffer[:, :batch_size]
|
||||
for buffer in self.mamba_gc_cache_buffer)
|
||||
return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
|
||||
|
||||
def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
|
||||
for req_id in finished_seq_groups_req_ids:
|
||||
if req_id in self.mamba_cache_indices_mapping:
|
||||
self.mamba_cache_indices_mapping.pop(req_id)
|
||||
|
||||
def _first_free_index_in_mamba_cache(self) -> int:
|
||||
if self.mamba_cache:
|
||||
def _first_free_index_in_mamba_cache(
|
||||
self, indices_range: Optional[List[int]] = None) -> int:
|
||||
assert self.mamba_cache is not None
|
||||
if indices_range is None:
|
||||
max_possible_batch_size = self.mamba_cache[0].shape[1]
|
||||
occupied = [
|
||||
id for seq_ids in self.mamba_cache_indices_mapping.values()
|
||||
for id in seq_ids.values()
|
||||
]
|
||||
first_free_index = [
|
||||
i not in occupied for i in range(max_possible_batch_size)
|
||||
].index(True)
|
||||
return first_free_index
|
||||
return 0
|
||||
indices_range = list(range(max_possible_batch_size))
|
||||
all_occupied_indices = self._get_all_occupied_indices()
|
||||
for i in indices_range:
|
||||
if i not in all_occupied_indices:
|
||||
return i
|
||||
raise Exception("Couldn't find a free spot in the mamba cache! This"
|
||||
"should never happen")
|
||||
|
||||
def _get_mamba_cache_shape(
|
||||
self
|
||||
@ -819,20 +848,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
|
||||
[layer_type == "mamba" for layer_type in layers_type])
|
||||
max_batch_size = (_get_graph_batch_size(
|
||||
self.scheduler_config.max_num_seqs) if self.scheduler_config else
|
||||
max(_BATCH_SIZES_TO_CAPTURE)) + 10
|
||||
max(_BATCH_SIZES_TO_CAPTURE) + 2)
|
||||
conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
|
||||
assert conv_state_shape is not None and temporal_state_shape is not None
|
||||
|
||||
for buffername in ["mamba_cache", "mamba_gc_cache_buffer"]:
|
||||
buffer = (torch.empty(size=(mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda"),
|
||||
torch.empty(size=(mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda"))
|
||||
setattr(self, buffername, buffer)
|
||||
self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
|
||||
conv_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda"),
|
||||
torch.empty(size=(mamba_layers, max_batch_size) +
|
||||
temporal_state_shape,
|
||||
dtype=dtype,
|
||||
device="cuda"))
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
|
||||
@ -1711,9 +1711,6 @@ class CUDAGraphRunner:
|
||||
non_blocking=True)
|
||||
# Run the graph.
|
||||
self.graph.replay()
|
||||
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||
self.model.copy_outputs_after_cuda_graphs(self.input_buffers,
|
||||
**kwargs)
|
||||
# Return the output tensor.
|
||||
if get_pp_group().is_last_rank:
|
||||
return self.output_buffers["hidden_states"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user