[V1] Support per-request seed (#9945)

Signed-off-by: Nick Hill <nickhill@us.ibm.com>

commit 1f1b6d6eda (parent 3bb4befea7)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Dict
 
 import torch
 
@@ -16,7 +16,6 @@ class SamplingMetadata:
     no_top_p: bool
     no_top_k: bool
 
-    generators: List[Optional[torch.Generator]]
-    no_generator: bool
+    generators: Dict[int, torch.Generator]
 
     max_num_logprobs: int
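Note: the SamplingMetadata change above swaps a dense per-row List[Optional[torch.Generator]] plus a no_generator flag for a sparse mapping from batch row index to that row's seeded generator; unseeded rows simply have no entry. A minimal standalone sketch of that shape (the class name and everything except the generators field are illustrative, not the real dataclass):

from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class SamplingMetadataSketch:
    # Only the seed-related field from this diff; all other fields omitted.
    # req_index -> seeded generator; unseeded rows simply have no entry.
    generators: Dict[int, torch.Generator]


# Batch of 4 requests where only rows 1 and 3 carry their own seed.
meta = SamplingMetadataSketch(generators={
    1: torch.Generator().manual_seed(42),
    3: torch.Generator().manual_seed(7),
})
assert 0 not in meta.generators  # unseeded rows fall back to the default RNG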
@@ -1,5 +1,5 @@
 """A layer that samples the next tokens from the model's outputs."""
-from typing import List, Optional
+from typing import Dict
 
 import torch
 import torch.nn as nn
@@ -84,22 +84,21 @@ class Sampler(nn.Module):
     def random_sample(
         self,
         probs: torch.Tensor,
-        generators: List[Optional[torch.Generator]],
-        no_generator: bool,
+        generators: Dict[int, torch.Generator],
     ) -> torch.Tensor:
         q = torch.empty_like(probs)
         # NOTE(woosuk): To batch-process the requests without their own seeds,
         # which is the common case, we first assume that every request does
         # not have its own seed. Then, we overwrite the values for the requests
         # that have their own seeds.
-        q.exponential_()
-        if not no_generator:
-            assert len(generators) == probs.shape[0]
+        if len(generators) != probs.shape[0]:
+            # This might still be done here unnecessarily if there are greedies
+            q.exponential_()
+        if generators:
             # TODO(woosuk): This can be slow because we handle each request
             # one by one. Optimize this.
-            for i, generator in enumerate(generators):
-                if generator is not None:
-                    q[i].exponential_(generator=generator)
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
         return probs.div_(q).argmax(dim=-1).view(-1)
 
     def sample(
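Note: random_sample above relies on the exponential-race (Gumbel-max style) trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws each row's token in proportion to probs, so seeded rows only need their noise drawn from their own generator. A standalone CPU sketch of that idea, assuming a (batch, vocab) probability tensor and the per-row generator dict introduced in this diff:

from typing import Dict

import torch


def random_sample_sketch(probs: torch.Tensor,
                         generators: Dict[int, torch.Generator]) -> torch.Tensor:
    """Draw one token id per row of `probs` (shape: batch x vocab)."""
    q = torch.empty_like(probs)
    # Fill every row from the default RNG, then overwrite the seeded rows.
    q.exponential_()
    for i, generator in generators.items():
        q[i].exponential_(generator=generator)
    # argmax of probs / Exp(1) noise is a draw from the categorical distribution.
    return probs.div_(q).argmax(dim=-1).view(-1)


probs = torch.softmax(torch.randn(4, 32), dim=-1)
gens = {1: torch.Generator().manual_seed(0), 3: torch.Generator().manual_seed(0)}
tokens = random_sample_sketch(probs.clone(), gens)  # rows 1 and 3 are reproducible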
@@ -112,13 +111,11 @@ class Sampler(nn.Module):
         if sampling_metadata.all_greedy:
             return self.greedy_sample(probs)
         if sampling_metadata.all_random:
-            return self.random_sample(probs, sampling_metadata.generators,
-                                      sampling_metadata.no_generator)
+            return self.random_sample(probs, sampling_metadata.generators)
 
         greedy_sampled = self.greedy_sample(probs)
         random_sampled = self.random_sample(probs,
-                                            sampling_metadata.generators,
-                                            sampling_metadata.no_generator)
+                                            sampling_metadata.generators)
         sampled = torch.where(
             sampling_metadata.temperature < _SAMPLING_EPS,
             greedy_sampled,
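Note: when a batch mixes greedy and seeded/random requests, both paths are computed and torch.where then picks per row based on temperature. A tiny standalone illustration of that selection pattern (tensors and the epsilon value are made up for the example; the real constant is defined elsewhere in the module):

import torch

_SAMPLING_EPS = 1e-5  # illustrative threshold for "temperature is effectively zero"

temperature = torch.tensor([0.0, 0.8, 0.0, 1.0])
greedy_sampled = torch.tensor([11, 12, 13, 14])
random_sampled = torch.tensor([21, 22, 23, 24])

# Rows with ~zero temperature keep the greedy token, the rest keep the random one.
sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled, random_sampled)
# -> tensor([11, 22, 13, 24])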
@@ -128,13 +128,20 @@ class GPUModelRunner:
         # Add new requests to the cached states.
         for req_data in scheduler_output.scheduled_new_reqs:
             req_id = req_data.req_id
+            sampling_params = req_data.sampling_params
+            if sampling_params.seed is not None:
+                generator = torch.Generator(device=self.device)
+                generator.manual_seed(sampling_params.seed)
+            else:
+                generator = None
+
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=req_data.prompt_token_ids,
                 prompt=req_data.prompt,
                 multi_modal_data=req_data.multi_modal_data,
-                sampling_params=req_data.sampling_params,
-                generator=None,  # TODO
+                sampling_params=sampling_params,
+                generator=generator,
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
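Note: the hunk above is where a request's seed (arriving as sampling_params.seed) turns into a long-lived torch.Generator that is cached with the request; requests without a seed keep generator=None and share the default stream. A standalone sketch of that construction, with the device defaulted to CPU for illustration (the runner uses self.device):

from typing import Optional

import torch


def make_generator(seed: Optional[int],
                   device: str = "cpu") -> Optional[torch.Generator]:
    """Build a seeded generator for a request, or None if it has no seed."""
    if seed is None:
        return None
    generator = torch.Generator(device=device)
    generator.manual_seed(seed)
    return generator


# Two requests created with the same seed get identical random streams.
g1, g2 = make_generator(1234), make_generator(1234)
assert torch.equal(torch.rand(3, generator=g1), torch.rand(3, generator=g2))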
@@ -342,11 +349,9 @@ class GPUModelRunner:
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators[i]
+                generator = self.input_batch.generators.get(i)
                 if generator is not None:
-                    offset = generator.get_offset()
-                    generator = generator.set_offset(offset - 1)
-                    self.input_batch.generators[i] = generator
+                    generator.set_offset(generator.get_offset() - 1)
 
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
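Note: the rewind above matters for partial (chunked) requests: a token was drawn from the request's seeded stream but then discarded, so the generator has to be stepped back or the request would consume different randomness than an unchunked run. The commit does this with get_offset()/set_offset() on the generator; the sketch below illustrates the same idea portably with full state snapshots instead (a hypothetical helper, not the code in this commit):

import torch


def sample_or_rollback(probs: torch.Tensor, generator: torch.Generator,
                       keep: bool) -> torch.Tensor:
    """Draw one token; if the draw is discarded, restore the RNG state."""
    state = generator.get_state()  # snapshot before sampling
    token = torch.multinomial(probs, 1, generator=generator)
    if not keep:
        generator.set_state(state)  # as if the token was never sampled
    return token


gen = torch.Generator().manual_seed(0)
probs = torch.full((8,), 1 / 8)
_ = sample_or_rollback(probs, gen, keep=False)  # discarded draw, state rolled back
a = sample_or_rollback(probs, gen, keep=True)

fresh = torch.Generator().manual_seed(0)
b = sample_or_rollback(probs, fresh, keep=True)
assert torch.equal(a, b)  # the discarded draw did not perturb the seeded stream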
@@ -494,8 +499,8 @@ class InputBatch:
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()
 
-        self.generators: List[Optional[torch.Generator]] = [None
-                                                            ] * max_num_reqs
+        # req_index -> generator
+        self.generators: Dict[int, torch.Generator] = {}
 
         self.num_logprobs: Dict[str, int] = {}
         self.prompt_logprob_reqs: Set[str] = set()
@@ -509,8 +514,9 @@ class InputBatch:
         req_index = self.num_reqs
         assert req_index < self.max_num_reqs
 
-        self.req_ids[req_index] = request.req_id
-        self.req_id_to_index[request.req_id] = req_index
+        req_id = request.req_id
+        self.req_ids[req_index] = req_id
+        self.req_id_to_index[req_id] = req_index
 
         # Copy the prompt token ids and output token ids.
         num_prompt_tokens = len(request.prompt_token_ids)
@@ -528,27 +534,24 @@ class InputBatch:
         sampling_params = request.sampling_params
         self.temperature_cpu[req_index] = sampling_params.temperature
         if sampling_params.sampling_type == SamplingType.GREEDY:
-            self.greedy_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM:
-            self.random_reqs.add(req_index)
-        elif sampling_params.sampling_type == SamplingType.RANDOM_SEED:
-            # TODO(woosuk): Support per-request random seed.
-            raise NotImplementedError("Per-request seed is not supported yet.")
+            self.greedy_reqs.add(req_id)
+        else:
+            self.random_reqs.add(req_id)
 
         self.top_p_cpu[req_index] = sampling_params.top_p
         if sampling_params.top_p < 1:
-            self.top_p_reqs.add(req_index)
+            self.top_p_reqs.add(req_id)
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
-            self.top_k_reqs.add(req_index)
+            self.top_k_reqs.add(req_id)
 
         self.generators[req_index] = request.generator
 
         num_logprobs = sampling_params.logprobs
         if num_logprobs is not None and num_logprobs > 0:
-            self.num_logprobs[request.req_id] = num_logprobs
+            self.num_logprobs[req_id] = num_logprobs
         if sampling_params.prompt_logprobs:
-            self.prompt_logprob_reqs.add(req_index)
+            self.prompt_logprob_reqs.add(req_id)
 
     def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
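Note: besides wiring in the generator, this hunk also makes the bookkeeping keys consistent: string request ids go into the Set[str] membership sets (greedy_reqs, top_p_reqs, ...), while integer batch indices key the dense arrays and the generators dict. A compressed illustration of the two key spaces (names are illustrative):

from typing import Dict, Set

import torch

req_id: str = "req-42"  # stable identifier of the request
req_index: int = 3      # the request's current slot in the batch

greedy_reqs: Set[str] = set()
generators: Dict[int, torch.Generator] = {}

greedy_reqs.add(req_id)                                    # id-keyed membership set
generators[req_index] = torch.Generator().manual_seed(7)   # index-keyed per-slot state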
@@ -560,7 +563,7 @@ class InputBatch:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
-        self.generators[req_index] = None
+        self.generators.pop(req_index, None)
         self.num_logprobs.pop(req_id, None)
         self.prompt_logprob_reqs.discard(req_id)
         return req_index
@@ -612,7 +615,9 @@ class InputBatch:
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
             self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
-            self.generators[empty_index] = self.generators[last_req_index]
+            generator = self.generators.pop(last_req_index, None)
+            if generator is not None:
+                self.generators[empty_index] = generator
 
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
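Note: because the generators dict is keyed by batch index, compacting the batch has to move an entry from the vacated slot to the new one and drop it entirely if the moved request had no seed. A minimal standalone sketch of that re-keying (indices are made up):

from typing import Dict

import torch

generators: Dict[int, torch.Generator] = {5: torch.Generator().manual_seed(9)}

empty_index, last_req_index = 2, 5  # the request in slot 5 moves into empty slot 2

generator = generators.pop(last_req_index, None)
if generator is not None:
    generators[empty_index] = generator  # a seeded request keeps its generator

assert set(generators) == {2}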
@@ -636,8 +641,7 @@ class InputBatch:
             top_k=self.top_k[:self.num_reqs],
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
-            generators=self.generators[:self.num_reqs],
-            no_generator=self.no_generator,
+            generators=self.generators,
             max_num_logprobs=self.max_num_logprobs,
         )
 
@@ -661,16 +665,9 @@ class InputBatch:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0
 
-    @property
-    def no_generator(self) -> bool:
-        return len(self.generators) == 0
-
     @property
     def max_num_logprobs(self) -> int:
-        if self.num_logprobs:
-            return max(self.num_logprobs.values())
-        else:
-            return 0
+        return max(self.num_logprobs.values()) if self.num_logprobs else 0
 
     @property
     def no_logprob(self) -> bool: