[V1] Support per-request seed (#9945)

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
Author: Nick Hill <nickhill@us.ibm.com>
Date: 2024-11-03 17:14:17 +00:00 (committed by GitHub)
Parent commit: 3bb4befea7
This commit: 1f1b6d6eda
3 changed files with 41 additions and 48 deletions
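In user-facing terms, a request that sets a seed in its sampling parameters now reproduces its sampled tokens across runs, even at non-zero temperature. A minimal sketch of that behavior, assuming a vLLM build with the V1 engine enabled (e.g. via the VLLM_USE_V1 environment variable around the time of this commit); the model name and prompt are arbitrary placeholders:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
# temperature > 0, but a fixed per-request seed
params = SamplingParams(temperature=1.0, seed=42, max_tokens=16)

first = llm.generate(["The capital of France is"], params)[0].outputs[0].text
second = llm.generate(["The capital of France is"], params)[0].outputs[0].text
assert first == second  # same seed -> same sampled tokens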

File 1 of 3 (class SamplingMetadata):

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict
 
 import torch
@@ -16,7 +16,6 @@ class SamplingMetadata:
     no_top_p: bool
     no_top_k: bool
 
-    generators: List[Optional[torch.Generator]]
-    no_generator: bool
+    generators: Dict[int, torch.Generator]
 
     max_num_logprobs: int
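The dense List[Optional[torch.Generator]] plus a separate no_generator flag becomes a sparse mapping from batch row index to generator, so the common no-seed case is simply an empty dict. A rough sketch of how such a mapping is populated (the helper and its names are illustrative, not part of the diff):

from typing import Dict, List, Optional

import torch


def build_generators(seeds: List[Optional[int]],
                     device: str = "cpu") -> Dict[int, torch.Generator]:
    generators: Dict[int, torch.Generator] = {}
    for row, seed in enumerate(seeds):
        if seed is not None:
            generator = torch.Generator(device=device)
            generator.manual_seed(seed)
            generators[row] = generator
    return generators


# Common case: no request supplies a seed, so there is nothing to carry around.
assert build_generators([None, None, None]) == {}
# Only row 1 asked for a seed, so only row 1 gets an entry.
assert list(build_generators([None, 1234, None])) == [1]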

File 2 of 3 (class Sampler):

@@ -1,5 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import List, Optional
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -84,22 +84,21 @@ class Sampler(nn.Module):
     def random_sample(
         self,
         probs: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
-        no_generator: bool,
+        generators: Dict[int, torch.Generator],
     ) -> torch.Tensor:
         q = torch.empty_like(probs)
         # NOTE(woosuk): To batch-process the requests without their own seeds,
         # which is the common case, we first assume that every request does
         # not have its own seed. Then, we overwrite the values for the requests
         # that have their own seeds.
-        q.exponential_()
-        if not no_generator:
-            assert len(generators) == probs.shape[0]
+        if len(generators) != probs.shape[0]:
+            # This might still be done here unnecessarily if there are greedies
+            q.exponential_()
+        if generators:
             # TODO(woosuk): This can be slow because we handle each request
             # one by one. Optimize this.
-            for i, generator in enumerate(generators):
-                if generator is not None:
-                    q[i].exponential_(generator=generator)
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
         return probs.div_(q).argmax(dim=-1).view(-1)
 
     def sample(
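For context on why this sampler works at all: dividing each probability row by i.i.d. Exponential(1) noise and taking the argmax draws exactly from that row's categorical distribution (the "exponential race" view of the Gumbel-max trick), and rows listed in generators simply redraw their noise from their own seeded generator. A self-contained sketch of the same idea (not the vLLM class itself):

from typing import Dict

import torch


def random_sample(probs: torch.Tensor,
                  generators: Dict[int, torch.Generator]) -> torch.Tensor:
    q = torch.empty_like(probs)
    if len(generators) != probs.shape[0]:
        # Fill every row with unseeded noise unless all rows are seeded.
        q.exponential_()
    for i, generator in generators.items():
        # Seeded rows overwrite their noise using their own generator.
        q[i].exponential_(generator=generator)
    return probs.div(q).argmax(dim=-1).view(-1)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
seeded = torch.Generator().manual_seed(42)
first = random_sample(probs, {0: seeded})
seeded.manual_seed(42)
second = random_sample(probs, {0: seeded})
assert first[0] == second[0]  # row 0 is seeded, so it is reproducible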
@@ -112,13 +111,11 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             return self.greedy_sample(probs)
         if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators,
-                                      sampling_metadata.no_generator)
+            return self.random_sample(probs, sampling_metadata.generators)
 
         greedy_sampled = self.greedy_sample(probs)
-        random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators,
-                                            sampling_metadata.no_generator)
+        random_sampled = self.random_sample(probs,
+                                            sampling_metadata.generators)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
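When a batch mixes greedy and random requests, both samplers run and torch.where picks per row: rows whose temperature is effectively zero keep the greedy token, the rest keep the (possibly seeded) random token. A toy illustration, with an assumed epsilon value (the real constant lives in the sampler module):

import torch

_SAMPLING_EPS = 1e-5  # assumed value for illustration
temperature = torch.tensor([0.0, 0.8])   # row 0 greedy, row 1 random
greedy_sampled = torch.tensor([3, 9])
random_sampled = torch.tensor([7, 5])

sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled, random_sampled)
assert sampled.tolist() == [3, 5]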

File 3 of 3 (classes GPUModelRunner and InputBatch):

@@ -128,13 +128,20 @@ class GPUModelRunner:
         # Add new requests to the cached states.
         for req_data in scheduler_output.scheduled_new_reqs:
             req_id = req_data.req_id
+            sampling_params = req_data.sampling_params
+            if sampling_params.seed is not None:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_params.seed)
+            else:
+                generator = None
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
                 multi_modal_data=req_data.multi_modal_data,
-                sampling_params=req_data.sampling_params,
-                generator=None,  # TODO
+                sampling_params=sampling_params,
+                generator=generator,
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
@@ -342,11 +349,9 @@ class GPUModelRunner:
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators[i]
+                generator = self.input_batch.generators.get(i)
                 if generator is not None:
-                    offset = generator.get_offset()
-                    generator = generator.set_offset(offset - 1)
-                    self.input_batch.generators[i] = generator
+                    generator.set_offset(generator.get_offset() - 1)
 
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
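The rewind handles partially scheduled requests: their sampled token is discarded, so the generator is stepped back as if that token had never been drawn, and the real sample later replays the same randomness regardless of how the request was chunked. get_offset()/set_offset() are offset accessors PyTorch provides on its Philox-based (e.g. CUDA) generators; a portable sketch of the same rewind idea using get_state()/set_state() on a CPU generator:

import torch

generator = torch.Generator().manual_seed(7)
state = generator.get_state()                   # remember the state before the throwaway draw
discarded = torch.rand(1, generator=generator)  # token that will be ignored
generator.set_state(state)                      # rewind as if it was never sampled
redone = torch.rand(1, generator=generator)
assert torch.equal(discarded, redone)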
@@ -494,8 +499,8 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()
 
-        self.generators: List[Optional[torch.Generator]] = [None
-                                                            ] * max_num_reqs
+        # req_index -> generator
+        self.generators: Dict[int, torch.Generator] = {}
 
         self.num_logprobs: Dict[str, int] = {}
         self.prompt_logprob_reqs: Set[str] = set()
@@ -509,8 +514,9 @@ class InputBatch:
         req_index = self.num_reqs
         assert req_index < self.max_num_reqs
 
-        self.req_ids[req_index] = request.req_id
-        self.req_id_to_index[request.req_id] = req_index
+        req_id = request.req_id
+        self.req_ids[req_index] = req_id
+        self.req_id_to_index[req_id] = req_index
 
         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
@@ -528,27 +534,24 @@ class InputBatch:
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            self.greedy_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM:
-            self.random_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-            # TODO(woosuk): Support per-request random seed.
-            raise NotImplementedError("Per-request seed is not supported yet.")
+            self.greedy_reqs.add(req_id)
+        else:
+            self.random_reqs.add(req_id)
 
         self.top_p_cpu[req_index] = sampling_params.top_p
         if sampling_params.top_p < 1:
-            self.top_p_reqs.add(req_index)
+            self.top_p_reqs.add(req_id)
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
-            self.top_k_reqs.add(req_index)
+            self.top_k_reqs.add(req_id)
 
         self.generators[req_index] = request.generator
 
         num_logprobs = sampling_params.logprobs
         if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[request.req_id] = num_logprobs
+            self.num_logprobs[req_id] = num_logprobs
         if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_index)
+            self.prompt_logprob_reqs.add(req_id)
 
     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
@@ -560,7 +563,7 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
-        self.generators[req_index] = None
+        self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
         return req_index
@@ -612,7 +615,9 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.generators[empty_index] = self.generators[last_req_index]
+            generator = self.generators.pop(last_req_index, None)
+            if generator is not None:
+                self.generators[empty_index] = generator
 
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
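Because the map is keyed by row index, compacting the batch only has to move entries for rows that actually carry a generator; rows without a seed need no bookkeeping at all. A toy version of that move (the function name is illustrative):

from typing import Dict

import torch


def move_row(generators: Dict[int, torch.Generator],
             last_req_index: int, empty_index: int) -> None:
    generator = generators.pop(last_req_index, None)
    if generator is not None:
        generators[empty_index] = generator


generators = {3: torch.Generator().manual_seed(0)}
move_row(generators, last_req_index=3, empty_index=1)  # seeded row 3 moves into row 1
move_row(generators, last_req_index=2, empty_index=0)  # row 2 had no seed: no-op
assert list(generators) == [1]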
@@ -636,8 +641,7 @@ class InputBatch:
             top_k=self.top_k[:self.num_reqs],
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
-            generators=self.generators[:self.num_reqs],
-            no_generator=self.no_generator,
+            generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
         )
@@ -661,16 +665,9 @@ class InputBatch:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0
 
-    @property
-    def no_generator(self) -> bool:
-        return len(self.generators) == 0
-
     @property
     def max_num_logprobs(self) -> int:
-        if self.num_logprobs:
-            return max(self.num_logprobs.values())
-        else:
-            return 0
+        return max(self.num_logprobs.values()) if self.num_logprobs else 0
 
     @property
     def no_logprob(self) -> bool: