[Core] Avoid copying prompt/output tokens if no penalties are used (#5289)
parent 828da0d44e
commit a31cab7556
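Prompt and output token ids are consumed only by the presence/frequency/repetition penalty computation, so when a batch uses none of those penalties, building the padded CPU token tensors and copying them to the GPU is pure overhead. The diff below introduces a do_penalties flag and gates the token bookkeeping, the CPU-side padding, and the host-to-device copies on it. As a minimal standalone sketch of the pattern, with hypothetical, simplified names (not the vLLM code itself):

    import torch

    def gather_token_tensors(prompt_tokens, device, vocab_size, pin_memory):
        """Pad ragged token-id lists and move them to `device`, or skip
        all of that work when no penalties need them."""
        do_penalties = bool(prompt_tokens)

        if do_penalties:
            # Pad to a rectangle with an out-of-range id (vocab_size) so
            # penalty kernels can mask the padding out.
            max_len = max((len(t) for t in prompt_tokens), default=0)
            padded = [list(t) + [vocab_size] * (max_len - len(t))
                      for t in prompt_tokens]
            cpu_t = torch.tensor(padded, device="cpu", dtype=torch.long,
                                 pin_memory=pin_memory)
            # Asynchronous host-to-device copy; overlapping only pays off
            # when the source memory is pinned.
            return cpu_t.to(device=device, non_blocking=True)

        # Fast path: a zero-element tensor keeps the result tensor-typed
        # without issuing any copy.
        return torch.empty(0, device=device, dtype=torch.long)

The commit diff itself: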
@@ -386,16 +386,18 @@ class SamplingTensors:
                 presence_penalties += [0] * prefill_len
                 frequency_penalties += [0] * prefill_len
                 repetition_penalties += [1] * prefill_len
-                prompt_tokens.extend([] for _ in range(prefill_len))
-                output_tokens.extend([] for _ in range(prefill_len))
+                if do_penalties:
+                    prompt_tokens.extend([] for _ in range(prefill_len))
+                    output_tokens.extend([] for _ in range(prefill_len))
 
             if seq_group.do_sample:
                 sample_lens = len(seq_group.sample_indices)
                 assert sample_lens == len(seq_ids)
                 for seq_id in seq_ids:
                     seq_data = seq_group.seq_data[seq_id]
-                    prompt_tokens.append(seq_data.prompt_token_ids)
-                    output_tokens.append(seq_data.output_token_ids)
+                    if do_penalties:
+                        prompt_tokens.append(seq_data.prompt_token_ids)
+                        output_tokens.append(seq_data.output_token_ids)
                 temperatures += [temperature] * len(seq_ids)
                 top_ps += [top_p] * len(seq_ids)
                 top_ks += [top_k] * len(seq_ids)
@@ -443,18 +445,22 @@ class SamplingTensors:
         # Note that the performance will be very bad without
         # pinned memory.
         pin_memory = is_pin_memory_available()
-        prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
-                             default=0)
-        prompt_padded_tokens = [
-            tokens + [vocab_size] * (prompt_max_len - len(tokens))
-            for tokens in prompt_tokens
-        ]
-        output_max_len = max([len(tokens) for tokens in output_tokens],
-                             default=0)
-        output_padded_tokens = [
-            tokens + [vocab_size] * (output_max_len - len(tokens))
-            for tokens in output_tokens
-        ]
+
+        do_penalties = prompt_tokens or output_tokens
+
+        if do_penalties:
+            prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
+                                 default=0)
+            prompt_padded_tokens = [
+                tokens + [vocab_size] * (prompt_max_len - len(tokens))
+                for tokens in prompt_tokens
+            ]
+            output_max_len = max([len(tokens) for tokens in output_tokens],
+                                 default=0)
+            output_padded_tokens = [
+                tokens + [vocab_size] * (output_max_len - len(tokens))
+                for tokens in output_tokens
+            ]
 
         temperatures_t = torch.tensor(
             temperatures,
@@ -504,18 +510,22 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
-        prompt_tensor = torch.tensor(
-            prompt_padded_tokens,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
-        output_tensor = torch.tensor(
-            output_padded_tokens,
-            device="cpu",
-            dtype=torch.long,
-            pin_memory=pin_memory,
-        )
+        if do_penalties:
+            prompt_tensor = torch.tensor(
+                prompt_padded_tokens,
+                device="cpu",
+                dtype=torch.long,
+                pin_memory=pin_memory,
+            )
+            output_tensor = torch.tensor(
+                output_padded_tokens,
+                device="cpu",
+                dtype=torch.long,
+                pin_memory=pin_memory,
+            )
+        else:
+            prompt_tensor = None
+            output_tensor = None
         # need to transpose and make contiguous to
         # copy the tensor correctly.
         # [batch_size, n_seeds] -> [n_seeds, batch_size]
@@ -538,6 +548,16 @@ class SamplingTensors:
             extra_seeds_gpu = None
         sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
 
+        if do_penalties:
+            prompt_tokens_gpu = prompt_tensor.to(device=device,
+                                                 non_blocking=True)
+            output_tokens_gpu = output_tensor.to(device=device,
+                                                 non_blocking=True)
+        else:
+            empty_tensor = torch.empty(0, device=device, dtype=torch.long)
+            prompt_tokens_gpu = empty_tensor
+            output_tokens_gpu = empty_tensor
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),
@@ -549,8 +569,8 @@ class SamplingTensors:
                                                          non_blocking=True),
             repetition_penalties=repetition_penalties_t.to(device=device,
                                                            non_blocking=True),
-            prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
-            output_tokens=output_tensor.to(device=device, non_blocking=True),
+            prompt_tokens=prompt_tokens_gpu,
+            output_tokens=output_tokens_gpu,
             sampling_seeds=sampling_seeds_gpu,
             sample_indices=sample_indices_t.to(device=device,
                                                non_blocking=True),
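Note the design choice in the last two hunks: the no-penalty path returns zero-element tensors rather than None, so the SamplingTensors fields stay tensor-typed and downstream code needs no None checks; only the intermediate CPU tensors may be None, and they are never read in that case. For illustration, exercising both paths of the hypothetical gather_token_tensors sketch above (assumes a CUDA device):

    # Penalty path: ragged ids are padded and copied to the GPU.
    t = gather_token_tensors([[1, 2, 3], [4]], "cuda", vocab_size=32000,
                             pin_memory=True)
    assert t.shape == (2, 3) and t.device.type == "cuda"

    # Fast path: nothing recorded, so no host-to-device copy is issued.
    t = gather_token_tensors([], "cuda", vocab_size=32000, pin_memory=False)
    assert t.numel() == 0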