[Core] Use array to speedup padding (#6779)
parent 084a01fd35
commit 89a84b0bb7
vllm/model_executor/layers/sampler.py
@@ -220,7 +220,7 @@ def _apply_min_tokens_penalty(
             seqs_to_penalize: List[int] = []
             for j, seq_id in enumerate(seq_ids):
                 seq_data = seq_group.seq_data[seq_id]
-                if len(seq_data.output_token_ids) < min_tokens:
+                if len(seq_data.output_token_ids_array) < min_tokens:
                     seqs_to_penalize.append(j)

             if seqs_to_penalize:
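For context (not visible in this hunk): SequenceData.output_token_ids is a property that rebuilds a tuple on every access (see the vllm/sequence.py hunks below), so taking its len() here copied the entire output sequence on every sampler invocation. The new output_token_ids_array accessor returns the backing array directly. A minimal sketch of the difference, with an assumed sequence length:

    # Sketch, not part of the commit: why reading the raw array is cheaper.
    from array import array

    class Seq:
        def __init__(self, n: int) -> None:
            self._output_token_ids = array('l', range(n))

        @property
        def output_token_ids(self):  # old-style accessor: copies every id
            return tuple(self._output_token_ids)

        @property
        def output_token_ids_array(self):  # new-style accessor: no copy
            return self._output_token_ids

    seq = Seq(4096)
    len(seq.output_token_ids)        # allocates a 4096-element tuple first
    len(seq.output_token_ids_array)  # constant time, no allocation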
vllm/model_executor/sampling_metadata.py
@@ -1,4 +1,5 @@
 import random
+from array import array
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

@@ -329,8 +330,8 @@ class SamplingTensors:
             user-defined seed for each sequence.
         extra_entropy: extra entropy to use when generating seeds.
         """
-        prompt_tokens: List[List[int]] = []
-        output_tokens: List[List[int]] = []
+        prompt_tokens: List[array] = []
+        output_tokens: List[array] = []
         top_ks: List[int] = []
         temperatures: List[float] = []
         top_ps: List[float] = []
@@ -432,13 +433,15 @@ class SamplingTensors:
             if (seq_group.is_prompt
                     and sampling_params.prompt_logprobs is not None):
                 prefill_len = len(seq_group.prompt_logprob_indices)
-                prompt_tokens.extend([] for _ in range(prefill_len))
-                output_tokens.extend([] for _ in range(prefill_len))
+                prompt_tokens.extend(
+                    array('l') for _ in range(prefill_len))
+                output_tokens.extend(
+                    array('l') for _ in range(prefill_len))
             if seq_group.do_sample:
                 for seq_id in seq_ids:
                     seq_data = seq_group.seq_data[seq_id]
-                    prompt_tokens.append(list(seq_data.prompt_token_ids))
-                    output_tokens.append(list(seq_data.output_token_ids))
+                    prompt_tokens.append(seq_data.prompt_token_ids_array)
+                    output_tokens.append(seq_data.output_token_ids_array)

         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, min_ps, presence_penalties,
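Why arrays speed this up (context, not part of the diff): a Python list stores a pointer per element to a boxed int, so torch.tensor must walk and unbox every item, while array('l') is one contiguous C buffer that numpy and torch can adopt through the buffer protocol. A rough micro-benchmark sketch; the size is made up, and matching 'l' to int64 assumes a 64-bit C long, as on Linux:

    # Sketch, not from the commit: tensor creation from list vs array buffer.
    import time
    from array import array

    import numpy as np
    import torch

    ids_list = list(range(100_000))
    ids_arr = array('l', ids_list)  # 'l' is 8 bytes on 64-bit Linux/macOS

    t0 = time.perf_counter()
    t_list = torch.tensor(ids_list, dtype=torch.long)  # unboxes each int
    t1 = time.perf_counter()
    t_arr = torch.from_numpy(np.frombuffer(ids_arr, dtype=np.int64))  # zero-copy view
    t2 = time.perf_counter()

    assert torch.equal(t_list, t_arr)
    print(f"from list: {(t1 - t0) * 1e3:.2f} ms, from array: {(t2 - t1) * 1e3:.2f} ms")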
@@ -454,9 +457,9 @@ class SamplingTensors:
                    frequency_penalties: List[float],
                    repetition_penalties: List[float],
                    sampling_seeds: List[int], sample_indices: List[int],
-                   prompt_tokens: List[List[int]],
-                   output_tokens: List[List[int]], vocab_size: int,
-                   extra_seeds_to_generate: int, device: torch.device,
+                   prompt_tokens: List[array], output_tokens: List[array],
+                   vocab_size: int, extra_seeds_to_generate: int,
+                   device: torch.device,
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
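Inside from_lists these rows are padded to a common length before the tensors are built; with array rows, the concatenation that performs the padding is a bulk copy of machine words rather than a per-element loop. An illustrative helper built on the same idea (pad_rows and pad_id are made-up names, not vLLM API):

    # Sketch of the padding pattern enabled by array rows.
    from array import array
    from typing import List

    def pad_rows(rows: List[array], pad_id: int) -> List[array]:
        # Extend every row to the longest row's length with pad_id.
        max_len = max((len(r) for r in rows), default=0)
        return [r + array('l', [pad_id]) * (max_len - len(r)) for r in rows]

    rows = [array('l', [1, 2, 3]), array('l', [4])]
    print(pad_rows(rows, pad_id=0))  # second row becomes array('l', [4, 0, 0])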
vllm/sequence.py
@@ -3,6 +3,7 @@ import copy
 import enum
 import math
 from abc import ABC, abstractmethod
+from array import array
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
@@ -119,10 +120,10 @@ class SequenceData:
         prompt_token_ids: List[int],
         output_token_ids: Optional[List[int]] = None,
     ) -> None:
-        self._prompt_token_ids: List[int] = list(prompt_token_ids)
+        self._prompt_token_ids = array('l', prompt_token_ids)
         self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
-        self._output_token_ids: List[int] = (
-            list(output_token_ids) if output_token_ids is not None else [])
+        self._output_token_ids = array(
+            'l', output_token_ids if output_token_ids is not None else [])

         self.cumulative_logprob = 0.0
         # The number of tokens that are computed (that run against the model).
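For context on the storage change: array('l', ...) keeps token ids unboxed in a single contiguous allocation, whereas a list holds a pointer per element plus a boxed int object per value. A rough footprint sketch (exact numbers vary by CPython build, and small ints are interned):

    # Sketch: comparing the memory footprint of list vs array storage.
    import sys
    from array import array

    n = 10_000
    as_list = list(range(n))
    as_array = array('l', range(n))

    list_bytes = sys.getsizeof(as_list) + sum(sys.getsizeof(i) for i in as_list)
    array_bytes = sys.getsizeof(as_array)
    print(f"list: ~{list_bytes} bytes, array: ~{array_bytes} bytes")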
@@ -132,7 +133,7 @@ class SequenceData:
         self._update_cached_all_tokens()

     def _update_cached_all_tokens(self):
-        self._cached_all_token_ids: List[int] = (self._prompt_token_ids +
-                                                 self._output_token_ids)
+        self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
+                                                     self._output_token_ids)

     @property
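The subtlety here is the added list(...) call: concatenating the two buffers now produces an array, while _cached_all_token_ids is annotated and consumed as a plain List[int]. A two-line illustration:

    # Sketch: array + array yields an array, so list() restores the list type.
    from array import array

    merged = array('l', [1, 2]) + array('l', [3])
    print(type(merged), list(merged))  # <class 'array.array'> [1, 2, 3]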
@@ -141,19 +142,27 @@ class SequenceData:

     @prompt_token_ids.setter
     def prompt_token_ids(self, new_prompt_token_ids) -> None:
-        self._prompt_token_ids = list(new_prompt_token_ids)
+        self._prompt_token_ids = array('l', new_prompt_token_ids)
         self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
         self._update_cached_all_tokens()

+    @property
+    def prompt_token_ids_array(self) -> array:
+        return self._prompt_token_ids
+
     @property
     def output_token_ids(self) -> Tuple[int, ...]:
         return tuple(self._output_token_ids)

     @output_token_ids.setter
     def output_token_ids(self, new_output_token_ids) -> None:
-        self._output_token_ids = list(new_output_token_ids)
+        self._output_token_ids = array('l', new_output_token_ids)
         self._update_cached_all_tokens()

+    @property
+    def output_token_ids_array(self) -> array:
+        return self._output_token_ids
+
     def append_token_id(self, token_id: int, logprob: float) -> None:
         self._output_token_ids.append(token_id)
         self._cached_all_token_ids.append(token_id)
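Both new *_array properties return the internal buffer itself, not a copy; that is what lets the sampler and sampling_metadata paths above read token ids without allocation, but it also means callers must treat the result as read-only. A usage sketch, assuming the SequenceData class shown in this diff (vllm.sequence):

    # Usage sketch: the *_array accessors alias internal state.
    from vllm.sequence import SequenceData

    seq = SequenceData(prompt_token_ids=[1, 2, 3])
    seq.append_token_id(7, logprob=-0.1)

    print(seq.output_token_ids)        # (7,) -- an independent tuple copy
    buf = seq.output_token_ids_array   # array('l', [7]) -- aliases seq's state
    # buf.append(9)  # would mutate seq without updating its cached tokens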