[V1] Support per-request seed (#9945)

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Authored by Nick Hill on 2024-11-03 17:14:17 +00:00; committed by GitHub
parent 3bb4befea7
commit 1f1b6d6eda
3 changed files with 41 additions and 48 deletions
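
A hedged usage sketch of what this commit enables (not part of the diff): each request can pass its own seed through SamplingParams, so seeded requests sample reproducibly while unseeded ones keep using the engine's default RNG. The model name and prompts below are placeholders, and opting into the V1 engine via the VLLM_USE_V1 environment variable is an assumption about how V1 was enabled around the time of this commit.

import os

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine (assumption for this era)

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model

# The first two requests share a seed and should produce identical samples;
# the third request has no seed and uses the default generator.
params = [
    SamplingParams(temperature=1.0, seed=42, max_tokens=16),
    SamplingParams(temperature=1.0, seed=42, max_tokens=16),
    SamplingParams(temperature=1.0, max_tokens=16),
]
for output in llm.generate(["Hello, my name is"] * 3, params):
    print(output.outputs[0].text)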

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict

 import torch
@@ -16,7 +16,6 @@ class SamplingMetadata:
     no_top_p: bool
     no_top_k: bool

-    generators: List[Optional[torch.Generator]]
-    no_generator: bool
+    generators: Dict[int, torch.Generator]

     max_num_logprobs: int
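
The dense List[Optional[torch.Generator]] plus the separate no_generator flag is replaced by a sparse mapping keyed by batch row index: only requests that carry their own seed get an entry, so an empty dict plays the role of the old flag. A small illustrative sketch of that representation (not part of the diff):

from typing import Dict

import torch

# Only seeded requests appear in the mapping; the key is the row index of the
# request inside the persistent batch.
generators: Dict[int, torch.Generator] = {}
generators[2] = torch.Generator().manual_seed(42)  # row 2 has its own seed

print(bool(generators))   # True  -> at least one seeded request in the batch
print(generators.get(0))  # None  -> row 0 falls back to the default RNG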

View File

@@ -1,5 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import List, Optional
+from typing import Dict

 import torch
 import torch.nn as nn
@@ -84,22 +84,21 @@ class Sampler(nn.Module):
     def random_sample(
         self,
         probs: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
-        no_generator: bool,
+        generators: Dict[int, torch.Generator],
     ) -> torch.Tensor:
         q = torch.empty_like(probs)
         # NOTE(woosuk): To batch-process the requests without their own seeds,
         # which is the common case, we first assume that every request does
         # not have its own seed. Then, we overwrite the values for the requests
         # that have their own seeds.
-        q.exponential_()
-        if not no_generator:
-            assert len(generators) == probs.shape[0]
+        if len(generators) != probs.shape[0]:
+            # This might still be done here unnecessarily if there are greedies
+            q.exponential_()
+        if generators:
             # TODO(woosuk): This can be slow because we handle each request
             # one by one. Optimize this.
-            for i, generator in enumerate(generators):
-                if generator is not None:
-                    q[i].exponential_(generator=generator)
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
         return probs.div_(q).argmax(dim=-1).view(-1)

     def sample(
@@ -112,13 +111,11 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             return self.greedy_sample(probs)
         if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators,
-                                      sampling_metadata.no_generator)
+            return self.random_sample(probs, sampling_metadata.generators)

         greedy_sampled = self.greedy_sample(probs)
         random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators,
-                                            sampling_metadata.no_generator)
+                                            sampling_metadata.generators)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
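
random_sample relies on the exponential-noise trick: filling q with Exponential(1) samples and taking the argmax of probs / q draws from the categorical distribution defined by probs, since the minimum of independent exponentials with rates p_i lands on index i with probability p_i. With the new dict, only the rows listed in generators re-draw their noise from their own seeded RNG. A standalone sketch of the same idea (not a drop-in replacement for vLLM's Sampler; the function name is illustrative):

from typing import Dict

import torch


def seeded_random_sample(probs: torch.Tensor,
                         generators: Dict[int, torch.Generator]) -> torch.Tensor:
    q = torch.empty_like(probs)
    if len(generators) != probs.shape[0]:
        # Bulk-fill the noise for the common case of unseeded requests.
        q.exponential_()
    for i, generator in generators.items():
        # Overwrite only the rows that carry their own seed.
        q[i].exponential_(generator=generator)
    # In-place division, hence the .clone() calls below.
    return probs.div_(q).argmax(dim=-1).view(-1)


probs = torch.softmax(torch.randn(4, 8), dim=-1)
gen = torch.Generator().manual_seed(7)
first = seeded_random_sample(probs.clone(), {1: gen})
gen.manual_seed(7)  # re-seeding reproduces row 1's draw
second = seeded_random_sample(probs.clone(), {1: gen})
assert first[1] == second[1]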

View File

@@ -128,13 +128,20 @@ class GPUModelRunner:
         # Add new requests to the cached states.
         for req_data in scheduler_output.scheduled_new_reqs:
             req_id = req_data.req_id
+            sampling_params = req_data.sampling_params
+            if sampling_params.seed is not None:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_params.seed)
+            else:
+                generator = None
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
                 multi_modal_data=req_data.multi_modal_data,
-                sampling_params=req_data.sampling_params,
-                generator=None,  # TODO
+                sampling_params=sampling_params,
+                generator=generator,
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
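
The runner now builds a per-request torch.Generator on the model's device whenever the request specifies a seed. A minimal sketch of why that gives reproducible sampling, using a CPU generator for portability (the diff passes device=self.device, typically a GPU):

import torch

seed = 1234
generator = torch.Generator()  # the runner uses torch.Generator(device=self.device)
generator.manual_seed(seed)
first = torch.empty(4).exponential_(generator=generator)

generator.manual_seed(seed)    # same seed -> same stream of draws
second = torch.empty(4).exponential_(generator=generator)
assert torch.equal(first, second)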
@@ -342,11 +349,9 @@ class GPUModelRunner:
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators[i]
+                generator = self.input_batch.generators.get(i)
                 if generator is not None:
-                    offset = generator.get_offset()
-                    generator = generator.set_offset(offset - 1)
-                    self.input_batch.generators[i] = generator
+                    generator.set_offset(generator.get_offset() - 1)

         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
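
When a partially scheduled (chunked) request's sampled token is thrown away, the generator must be rewound so the discarded draw does not shift the seeded request's random stream. The diff does this in place by stepping the generator's offset back with get_offset()/set_offset(); the sketch below shows the same snapshot-and-restore idea with the portable get_state()/set_state() API, which avoids assumptions about the offset granularity:

import torch

generator = torch.Generator()
generator.manual_seed(42)

state = generator.get_state()  # snapshot before the draw we may discard
discarded = torch.empty(1).exponential_(generator=generator)

generator.set_state(state)     # rewind: as if the token was never sampled
replay = torch.empty(1).exponential_(generator=generator)
assert torch.equal(discarded, replay)  # the stream is unperturbed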
@@ -494,8 +499,8 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()

-        self.generators: List[Optional[torch.Generator]] = [None
-                                                            ] * max_num_reqs
+        # req_index -> generator
+        self.generators: Dict[int, torch.Generator] = {}

         self.num_logprobs: Dict[str, int] = {}
         self.prompt_logprob_reqs: Set[str] = set()
@@ -509,8 +514,9 @@ class InputBatch:
             req_index = self.num_reqs
         assert req_index < self.max_num_reqs
-        self.req_ids[req_index] = request.req_id
-        self.req_id_to_index[request.req_id] = req_index
+        req_id = request.req_id
+        self.req_ids[req_index] = req_id
+        self.req_id_to_index[req_id] = req_index

         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
@@ -528,27 +534,24 @@
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            self.greedy_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM:
-            self.random_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-            # TODO(woosuk): Support per-request random seed.
-            raise NotImplementedError("Per-request seed is not supported yet.")
+            self.greedy_reqs.add(req_id)
+        else:
+            self.random_reqs.add(req_id)

         self.top_p_cpu[req_index] = sampling_params.top_p
         if sampling_params.top_p < 1:
-            self.top_p_reqs.add(req_index)
+            self.top_p_reqs.add(req_id)
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
-            self.top_k_reqs.add(req_index)
+            self.top_k_reqs.add(req_id)

         self.generators[req_index] = request.generator

         num_logprobs = sampling_params.logprobs
         if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[request.req_id] = num_logprobs
+            self.num_logprobs[req_id] = num_logprobs
         if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_index)
+            self.prompt_logprob_reqs.add(req_id)

     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
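
add_request no longer rejects SamplingType.RANDOM_SEED: seeded requests are classified as random like any other non-greedy request, and their generator is stored under the batch row index. A small sketch of the classification, assuming the SamplingParams.sampling_type behavior in vllm.sampling_params (temperature 0 maps to GREEDY, a set seed to RANDOM_SEED):

from vllm import SamplingParams
from vllm.sampling_params import SamplingType

greedy = SamplingParams(temperature=0.0)
seeded = SamplingParams(temperature=0.8, seed=42)
plain = SamplingParams(temperature=0.8)

assert greedy.sampling_type == SamplingType.GREEDY
assert seeded.sampling_type == SamplingType.RANDOM_SEED  # now handled, not rejected
assert plain.sampling_type == SamplingType.RANDOM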
@@ -560,7 +563,7 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
-        self.generators[req_index] = None
+        self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
         return req_index
@@ -612,7 +615,9 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.generators[empty_index] = self.generators[last_req_index]
+            generator = self.generators.pop(last_req_index, None)
+            if generator is not None:
+                self.generators[empty_index] = generator

             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
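
condense() compacts the persistent batch by moving the last request into an empty slot; with the sparse dict, the generator entry is popped from the old row and re-inserted at the new one only if the request actually has a seed. A toy sketch of that move (variable names follow the diff):

import torch

generators = {3: torch.Generator().manual_seed(42)}  # only row 3 is seeded
empty_index, last_req_index = 1, 3

generator = generators.pop(last_req_index, None)
if generator is not None:
    generators[empty_index] = generator

assert list(generators) == [1]  # unseeded rows never had an entry to move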
@@ -636,8 +641,7 @@ class InputBatch:
             top_k=self.top_k[:self.num_reqs],
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
-            generators=self.generators[:self.num_reqs],
-            no_generator=self.no_generator,
+            generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
         )
@@ -661,16 +665,9 @@ class InputBatch:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0

-    @property
-    def no_generator(self) -> bool:
-        return len(self.generators) == 0
-
     @property
     def max_num_logprobs(self) -> int:
-        if self.num_logprobs:
-            return max(self.num_logprobs.values())
-        else:
-            return 0
+        return max(self.num_logprobs.values()) if self.num_logprobs else 0

     @property
     def no_logprob(self) -> bool: