[V1] Support per-request seed (#9945)
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
parent 3bb4befea7
commit 1f1b6d6eda
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict
 
 import torch
 
@@ -16,7 +16,6 @@ class SamplingMetadata:
     no_top_p: bool
     no_top_k: bool
 
-    generators: List[Optional[torch.Generator]]
-    no_generator: bool
+    generators: Dict[int, torch.Generator]
 
     max_num_logprobs: int
@@ -1,5 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import List, Optional
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -84,22 +84,21 @@ class Sampler(nn.Module):
     def random_sample(
         self,
         probs: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
-        no_generator: bool,
+        generators: Dict[int, torch.Generator],
     ) -> torch.Tensor:
         q = torch.empty_like(probs)
         # NOTE(woosuk): To batch-process the requests without their own seeds,
         # which is the common case, we first assume that every request does
         # not have its own seed. Then, we overwrite the values for the requests
         # that have their own seeds.
-        q.exponential_()
-        if not no_generator:
-            assert len(generators) == probs.shape[0]
+        if len(generators) != probs.shape[0]:
+            # This might still be done here unnecessarily if there are greedies
+            q.exponential_()
+        if generators:
             # TODO(woosuk): This can be slow because we handle each request
             # one by one. Optimize this.
-            for i, generator in enumerate(generators):
-                if generator is not None:
-                    q[i].exponential_(generator=generator)
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
         return probs.div_(q).argmax(dim=-1).view(-1)
 
     def sample(
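For readers unfamiliar with the trick in `random_sample`: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is the exponential-race form of Gumbel-max sampling, and rows present in the `generators` dict draw their noise from their own seeded `torch.Generator`. Below is a minimal, self-contained sketch of that idea (not part of the commit; the helper name `toy_random_sample` is illustrative only):

```python
from typing import Dict

import torch


def toy_random_sample(probs: torch.Tensor,
                      generators: Dict[int, torch.Generator]) -> torch.Tensor:
    """argmax(p / Exp(1)) picks index i with probability p[i] (Gumbel-max trick).

    Rows listed in `generators` use their own seeded torch.Generator, so their
    draw is reproducible; the remaining rows share the default RNG, mirroring
    the batched fast path in the diff above.
    """
    q = torch.empty_like(probs)
    q.exponential_()                    # one batched draw for unseeded rows
    for i, gen in generators.items():   # overwrite rows that carry a seed
        q[i].exponential_(generator=gen)
    return probs.div(q).argmax(dim=-1)


probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.25, 0.25]])
# Row 1 is seeded: re-seeding the generator and repeating the call yields the
# same token index for that row every time.
token_a = toy_random_sample(probs, {1: torch.Generator().manual_seed(42)})
token_b = toy_random_sample(probs, {1: torch.Generator().manual_seed(42)})
assert token_a[1] == token_b[1]
```

Repeating the call with a generator re-seeded to the same value reproduces the same token for that row, which is exactly the guarantee a per-request seed is meant to provide.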
@@ -112,13 +111,11 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             return self.greedy_sample(probs)
         if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators,
-                                      sampling_metadata.no_generator)
+            return self.random_sample(probs, sampling_metadata.generators)
 
         greedy_sampled = self.greedy_sample(probs)
         random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators,
-                                            sampling_metadata.no_generator)
+                                            sampling_metadata.generators)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
@@ -128,13 +128,20 @@ class GPUModelRunner:
         # Add new requests to the cached states.
         for req_data in scheduler_output.scheduled_new_reqs:
             req_id = req_data.req_id
+            sampling_params = req_data.sampling_params
+            if sampling_params.seed is not None:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_params.seed)
+            else:
+                generator = None
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
                 multi_modal_data=req_data.multi_modal_data,
-                sampling_params=req_data.sampling_params,
-                generator=None,  # TODO
+                sampling_params=sampling_params,
+                generator=generator,
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
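At the API level, wiring `sampling_params.seed` into a per-request `torch.Generator` means a seeded request becomes reproducible even when batched with unseeded ones. A rough usage sketch under assumed conditions (the model name is only an example, and on builds of this vintage the V1 engine is opt-in, e.g. via `VLLM_USE_V1=1`):

```python
from vllm import LLM, SamplingParams

# Example model; any generative model works the same way here.
llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=1.0, seed=1234, max_tokens=16)

out_a = llm.generate(["The capital of France is"], params)
out_b = llm.generate(["The capital of France is"], params)

# The seeded request is deterministic on its own, regardless of which other
# (unseeded) requests happen to be batched alongside it.
assert out_a[0].outputs[0].text == out_b[0].outputs[0].text
```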
@@ -342,11 +349,9 @@ class GPUModelRunner:
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators[i]
+                generator = self.input_batch.generators.get(i)
                 if generator is not None:
-                    offset = generator.get_offset()
-                    generator = generator.set_offset(offset - 1)
-                    self.input_batch.generators[i] = generator
+                    generator.set_offset(generator.get_offset() - 1)
 
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
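The rewind above steps the generator's offset back so the discarded draw is replayed at the partial request's next real decoding step. As a rough illustration of the same "pretend the token was never sampled" idea, here is a sketch using the portable state API rather than the offset calls used in the diff (illustration only, not this code path):

```python
import torch

gen = torch.Generator().manual_seed(7)
snapshot = gen.get_state()                 # capture RNG state before the draw
first = torch.empty(4).exponential_(generator=gen)
gen.set_state(snapshot)                    # rewind: discard the consumed draw
again = torch.empty(4).exponential_(generator=gen)
assert torch.equal(first, again)           # the rewound generator repeats the draw
```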
@@ -494,8 +499,8 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()
 
-        self.generators: List[Optional[torch.Generator]] = [None
-                                                            ] * max_num_reqs
+        # req_index -> generator
+        self.generators: Dict[int, torch.Generator] = {}
 
         self.num_logprobs: Dict[str, int] = {}
         self.prompt_logprob_reqs: Set[str] = set()
@@ -509,8 +514,9 @@ class InputBatch:
         req_index = self.num_reqs
         assert req_index < self.max_num_reqs
 
-        self.req_ids[req_index] = request.req_id
-        self.req_id_to_index[request.req_id] = req_index
+        req_id = request.req_id
+        self.req_ids[req_index] = req_id
+        self.req_id_to_index[req_id] = req_index
 
         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
@@ -528,27 +534,24 @@ class InputBatch:
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            self.greedy_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM:
-            self.random_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-            # TODO(woosuk): Support per-request random seed.
-            raise NotImplementedError("Per-request seed is not supported yet.")
+            self.greedy_reqs.add(req_id)
+        else:
+            self.random_reqs.add(req_id)
 
         self.top_p_cpu[req_index] = sampling_params.top_p
         if sampling_params.top_p < 1:
-            self.top_p_reqs.add(req_index)
+            self.top_p_reqs.add(req_id)
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
-            self.top_k_reqs.add(req_index)
+            self.top_k_reqs.add(req_id)
 
         self.generators[req_index] = request.generator
 
         num_logprobs = sampling_params.logprobs
         if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[request.req_id] = num_logprobs
+            self.num_logprobs[req_id] = num_logprobs
         if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_index)
+            self.prompt_logprob_reqs.add(req_id)
 
     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
@@ -560,7 +563,7 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
-        self.generators[req_index] = None
+        self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
         return req_index
@@ -612,7 +615,9 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.generators[empty_index] = self.generators[last_req_index]
+            generator = self.generators.pop(last_req_index, None)
+            if generator is not None:
+                self.generators[empty_index] = generator
 
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
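Since `generators` is now a sparse dict keyed by batch row, the condense step only needs to move an entry when the departing row actually carried a seed. A small standalone sketch of that pattern (the helper name `move_row` is illustrative, not from the commit):

```python
from typing import Dict

import torch


def move_row(generators: Dict[int, torch.Generator],
             empty_index: int, last_req_index: int) -> None:
    """Relocate a (possibly absent) generator when a batch row is moved."""
    generator = generators.pop(last_req_index, None)
    if generator is not None:
        generators[empty_index] = generator


gens: Dict[int, torch.Generator] = {3: torch.Generator().manual_seed(0)}
move_row(gens, empty_index=1, last_req_index=3)   # seeded row 3 moves to slot 1
move_row(gens, empty_index=0, last_req_index=2)   # unseeded row 2: nothing to do
assert set(gens) == {1}
```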
@@ -636,8 +641,7 @@ class InputBatch:
             top_k=self.top_k[:self.num_reqs],
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
-            generators=self.generators[:self.num_reqs],
-            no_generator=self.no_generator,
+            generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
         )
 
@@ -661,16 +665,9 @@ class InputBatch:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0
 
-    @property
-    def no_generator(self) -> bool:
-        return len(self.generators) == 0
-
     @property
     def max_num_logprobs(self) -> int:
-        if self.num_logprobs:
-            return max(self.num_logprobs.values())
-        else:
-            return 0
+        return max(self.num_logprobs.values()) if self.num_logprobs else 0
 
     @property
     def no_logprob(self) -> bool: