[Core] Use array to speedup padding (#6779)

This commit is contained in:
Peng Guanwen 2024-07-26 12:31:31 +08:00 committed by GitHub
parent 084a01fd35
commit 89a84b0bb7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 29 additions and 17 deletions

View File

@@ -220,7 +220,7 @@ def _apply_min_tokens_penalty(
seqs_to_penalize: List[int] = [] seqs_to_penalize: List[int] = []
for j, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens: if len(seq_data.output_token_ids_array) < min_tokens:
seqs_to_penalize.append(j) seqs_to_penalize.append(j)
if seqs_to_penalize: if seqs_to_penalize:

View File

@@ -1,4 +1,5 @@
import random import random
from array import array
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@@ -329,8 +330,8 @@ class SamplingTensors:
user-defined seed for each sequence. user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds. extra_entropy: extra entropy to use when generating seeds.
""" """
prompt_tokens: List[List[int]] = [] prompt_tokens: List[array] = []
output_tokens: List[List[int]] = [] output_tokens: List[array] = []
top_ks: List[int] = [] top_ks: List[int] = []
temperatures: List[float] = [] temperatures: List[float] = []
top_ps: List[float] = [] top_ps: List[float] = []
@@ -432,13 +433,15 @@ class SamplingTensors:
if (seq_group.is_prompt if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
prefill_len = len(seq_group.prompt_logprob_indices) prefill_len = len(seq_group.prompt_logprob_indices)
prompt_tokens.extend([] for _ in range(prefill_len)) prompt_tokens.extend(
output_tokens.extend([] for _ in range(prefill_len)) array('l') for _ in range(prefill_len))
output_tokens.extend(
array('l') for _ in range(prefill_len))
if seq_group.do_sample: if seq_group.do_sample:
for seq_id in seq_ids: for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(list(seq_data.prompt_token_ids)) prompt_tokens.append(seq_data.prompt_token_ids_array)
output_tokens.append(list(seq_data.output_token_ids)) output_tokens.append(seq_data.output_token_ids_array)
sampling_tensors = SamplingTensors.from_lists( sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties, temperatures, top_ps, top_ks, min_ps, presence_penalties,
@@ -454,9 +457,9 @@ class SamplingTensors:
frequency_penalties: List[float], frequency_penalties: List[float],
repetition_penalties: List[float], repetition_penalties: List[float],
sampling_seeds: List[int], sample_indices: List[int], sampling_seeds: List[int], sample_indices: List[int],
prompt_tokens: List[List[int]], prompt_tokens: List[array], output_tokens: List[array],
output_tokens: List[List[int]], vocab_size: int, vocab_size: int, extra_seeds_to_generate: int,
extra_seeds_to_generate: int, device: torch.device, device: torch.device,
dtype: torch.dtype) -> "SamplingTensors": dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without # Note that the performance will be very bad without
# pinned memory. # pinned memory.

View File

@@ -3,6 +3,7 @@ import copy
import enum import enum
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from array import array
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
@@ -119,10 +120,10 @@ class SequenceData:
prompt_token_ids: List[int], prompt_token_ids: List[int],
output_token_ids: Optional[List[int]] = None, output_token_ids: Optional[List[int]] = None,
) -> None: ) -> None:
self._prompt_token_ids: List[int] = list(prompt_token_ids) self._prompt_token_ids = array('l', prompt_token_ids)
self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids) self._prompt_token_ids_tuple: Tuple[int, ...] = tuple(prompt_token_ids)
self._output_token_ids: List[int] = ( self._output_token_ids = array(
list(output_token_ids) if output_token_ids is not None else []) 'l', output_token_ids if output_token_ids is not None else [])
self.cumulative_logprob = 0.0 self.cumulative_logprob = 0.0
# The number of tokens that are computed (that run against the model). # The number of tokens that are computed (that run against the model).
@@ -132,7 +133,7 @@ class SequenceData:
self._update_cached_all_tokens() self._update_cached_all_tokens()
def _update_cached_all_tokens(self): def _update_cached_all_tokens(self):
self._cached_all_token_ids: List[int] = (self._prompt_token_ids + self._cached_all_token_ids: List[int] = list(self._prompt_token_ids +
self._output_token_ids) self._output_token_ids)
@property @property
@@ -141,19 +142,27 @@ class SequenceData:
@prompt_token_ids.setter @prompt_token_ids.setter
def prompt_token_ids(self, new_prompt_token_ids) -> None: def prompt_token_ids(self, new_prompt_token_ids) -> None:
self._prompt_token_ids = list(new_prompt_token_ids) self._prompt_token_ids = array('l', new_prompt_token_ids)
self._prompt_token_ids_tuple = tuple(new_prompt_token_ids) self._prompt_token_ids_tuple = tuple(new_prompt_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
@property
def prompt_token_ids_array(self) -> array:
return self._prompt_token_ids
@property @property
def output_token_ids(self) -> Tuple[int, ...]: def output_token_ids(self) -> Tuple[int, ...]:
return tuple(self._output_token_ids) return tuple(self._output_token_ids)
@output_token_ids.setter @output_token_ids.setter
def output_token_ids(self, new_output_token_ids) -> None: def output_token_ids(self, new_output_token_ids) -> None:
self._output_token_ids = list(new_output_token_ids) self._output_token_ids = array('l', new_output_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
@property
def output_token_ids_array(self) -> array:
return self._output_token_ids
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self, token_id: int, logprob: float) -> None:
self._output_token_ids.append(token_id) self._output_token_ids.append(token_id)
self._cached_all_token_ids.append(token_id) self._cached_all_token_ids.append(token_id)