Enhance SamplingParams (#96)
parent 55f8b0a5de
commit 42f1042e1c
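Replace SamplingParams.from_dict with a plain keyword constructor that carries defaults, rename max_num_steps to max_tokens and num_logprobs to logprobs, and update all call sites (benchmark, scheduler, sampler, HTTP frontend, and clients) accordingly.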
@@ -6,7 +6,7 @@ from tqdm import tqdm
 import numpy as np
 import torch
 
-from cacheflow.master.server import (
+from cacheflow.core.server import (
     add_server_arguments, process_server_arguments,
     init_local_server_and_frontend_with_arguments)
 from cacheflow.sampling_params import SamplingParams
@@ -15,15 +15,14 @@ from cacheflow.sampling_params import SamplingParams
 def main(args: argparse.Namespace):
     server, frontend = init_local_server_and_frontend_with_arguments(args)
 
-    sampling_params_dict = {
-        'n': args.n,
-        'temperature': 0.0 if args.use_beam_search else 1.0,
-        'top_p': 1.0,
-        'use_beam_search': args.use_beam_search,
-        'stop_token_ids': set(),
-        'max_num_steps': args.output_len,
-    }
-    sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    sampling_params = SamplingParams(
+        n=args.n,
+        temperature=0.0 if args.use_beam_search else 1.0,
+        top_p=1.0,
+        use_beam_search=args.use_beam_search,
+        stop_token_ids=set(),
+        max_tokens=args.output_len,
+    )
     print(sampling_params)
     input_token_ids = [0] * args.input_len
 
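With the defaults the constructor gains below, call sites only spell out the fields they care about. A minimal usage sketch (not part of the commit; it assumes only the constructor shown in the sampling_params.py hunks):

from cacheflow.sampling_params import SamplingParams

# Unspecified fields fall back to the new defaults
# (n=1, temperature=1.0, top_p=1.0, top_k=-1, max_tokens=16, logprobs=0).
params = SamplingParams(n=1, temperature=1.0, top_p=1.0, max_tokens=256)
print(params)  # __repr__ lists every field, which is what the benchmark prints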
@@ -31,7 +30,8 @@ def main(args: argparse.Namespace):
     if profile:
         torch.cuda.cudart().cudaProfilerStart()
     for _ in range(args.batch_size):
-        frontend._add_query(input_token_ids, sampling_params)
+        dummy_prompt = ""
+        frontend._add_query(dummy_prompt, input_token_ids, sampling_params)
     server.add_sequence_groups(frontend.get_inputs())
     start_time = time.time()
     while True:
@@ -316,7 +316,7 @@ class Scheduler:
                 continue
 
             # Check if the sequence has reached the maximum number of steps.
-            max_num_steps = self.sampling_params[group_id].max_num_steps
+            max_num_steps = self.sampling_params[group_id].max_tokens
             if self.num_steps[group_id] == max_num_steps:
                 self._free_seq(seq)
                 continue
@@ -89,8 +89,8 @@ class FastAPIServer:
 
     async def generate(self, request_dict: Dict):
         # Preprocess the request.
-        prompt = request_dict["prompt"]
-        sampling_params = SamplingParams.from_dict(request_dict)
+        prompt = request_dict.pop("prompt")
+        sampling_params = SamplingParams(**request_dict)
         sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
         token_ids = self.tokenizer.encode(prompt)
         seqs: List[Sequence] = []
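Unpacking the request body straight into the constructor relies on the new defaults. One behavioral shift worth noting: from_dict raised ValueError on unrecognized keys, while **-unpacking makes Python raise TypeError instead. A sketch under that assumption (the request dicts are illustrative):

from cacheflow.sampling_params import SamplingParams

request_dict = {"prompt": "Hello", "temperature": 0.8, "max_tokens": 64}
prompt = request_dict.pop("prompt")      # strip non-sampling fields first
params = SamplingParams(**request_dict)  # remaining keys must match kwargs

try:
    SamplingParams(**{"max_num_steps": 64})  # the pre-rename field
except TypeError as e:
    print(e)  # ... unexpected keyword argument 'max_num_steps'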
@@ -367,7 +367,7 @@ def _sample(
         next_token_ids = _sample_from_prompt(prob, sampling_params)
         # Get top-k log probabilities for the next tokens.
         next_logprobs = _get_topk_logprobs(
-            logprob, sampling_params.num_logprobs)
+            logprob, sampling_params.logprobs)
 
         # Build the output.
         for seq_id, next_token_id in zip(seq_ids, next_token_ids):
@@ -392,7 +392,7 @@ def _sample(
         next_logprobs: Dict[int, Dict[int, float]] = {}
         for i, seq_id in enumerate(seq_ids):
             next_logprobs[seq_id] = _get_topk_logprobs(
-                logprob[i], sampling_params.num_logprobs)
+                logprob[i], sampling_params.logprobs)
 
         # Build the output.
         for seq_id, parent_seq_id, next_token_id in zip(
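The diff only renames the attribute read here; the body of _get_topk_logprobs is not shown. For orientation, a hedged sketch of what such a helper typically does with torch.topk (the function name, signature, and return shape below are assumptions, not the repository's implementation):

from typing import Dict
import torch

def topk_logprobs_sketch(logprob: torch.Tensor, num: int) -> Dict[int, float]:
    # logprob: 1-D tensor of log-probabilities over the vocabulary.
    if num == 0:
        return {}
    values, indices = torch.topk(logprob, num)
    return {int(token_id): float(value)
            for token_id, value in zip(indices, values)}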
@@ -5,16 +5,16 @@ class SamplingParams:
 
     def __init__(
         self,
-        n: int,
-        presence_penalty: float,
-        frequency_penalty: float,
-        temperature: float,
-        top_p: float,
-        top_k: int,
-        use_beam_search: bool,
-        stop_token_ids: Set[int],
-        max_num_steps: int,
-        num_logprobs: int,
+        n: int = 1,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        temperature: float = 1.0,
+        top_p: float = 1.0,
+        top_k: int = -1,
+        use_beam_search: bool = False,
+        stop_token_ids: Set[int] = set(),
+        max_tokens: int = 16,
+        logprobs: int = 0,
     ) -> None:
         if n < 1:
             raise ValueError(f"n must be at least 1, got {n}.")
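One subtlety in the new signature: stop_token_ids: Set[int] = set() is a mutable default that Python evaluates once and shares across every call that omits the argument. Because FastAPIServer.generate mutates the set in place (stop_token_ids.add(...) above), instances built with the default can end up sharing state. A small demonstration of the hazard (illustrative, not in the commit):

from cacheflow.sampling_params import SamplingParams

a = SamplingParams()
a.stop_token_ids.add(2)

b = SamplingParams()
print(b.stop_token_ids)  # {2} -- b inherited a's mutation via the shared default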
@@ -32,12 +32,12 @@ class SamplingParams:
         if top_k < -1 or top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, "
                              f"got {top_k}.")
-        if max_num_steps < 1:
+        if max_tokens < 1:
             raise ValueError(
-                f"max_num_steps must be at least 1, got {max_num_steps}.")
-        if num_logprobs < 0:
+                f"max_tokens must be at least 1, got {max_tokens}.")
+        if logprobs < 0:
             raise ValueError(
-                f"num_logprobs must be non-negative, got {num_logprobs}.")
+                f"logprobs must be non-negative, got {logprobs}.")
 
         if use_beam_search:
             if n == 1:
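The validation messages track the renames, so out-of-range values fail under the new names. A quick illustrative check:

from cacheflow.sampling_params import SamplingParams

try:
    SamplingParams(max_tokens=0)
except ValueError as e:
    print(e)  # max_tokens must be at least 1, got 0.

try:
    SamplingParams(top_k=0)  # top_k must be -1 (disable) or at least 1
except ValueError as e:
    print(e)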
@@ -72,8 +72,8 @@ class SamplingParams:
         self.top_k = top_k
         self.use_beam_search = use_beam_search
         self.stop_token_ids = stop_token_ids
-        self.max_num_steps = max_num_steps
-        self.num_logprobs = num_logprobs
+        self.max_tokens = max_tokens
+        self.logprobs = logprobs
 
     def __repr__(self) -> str:
         return (f"SamplingParams(n={self.n}, "
@@ -84,23 +84,5 @@ class SamplingParams:
                 f"top_k={self.top_k},"
                 f"use_beam_search={self.use_beam_search}, "
                 f"stop_token_ids={self.stop_token_ids}, "
-                f"max_num_steps={self.max_num_steps}, "
-                f"num_logprobs={self.num_logprobs}")
-
-    @classmethod
-    def from_dict(cls, d: Dict) -> "SamplingParams":
-        sampling_params = cls(
-            n=d.pop("n", 1),
-            presence_penalty=d.pop("presence_penalty", 0.0),
-            frequency_penalty=d.pop("frequency_penalty", 0.0),
-            temperature=d.pop("temperature", 1.0),
-            top_p=d.pop("top_p", 1.0),
-            top_k=d.pop("top_k", -1),
-            use_beam_search=d.pop("use_beam_search", False),
-            stop_token_ids=set(d.pop("stop_token_ids", set())),
-            max_num_steps=d.pop("max_num_steps", 16),
-            num_logprobs=d.pop("num_logprobs", 0),
-        )
-        if d:
-            raise ValueError(f"Unrecognized keys in dict: {d.keys()}")
-        return sampling_params
+                f"max_tokens={self.max_tokens}, "
+                f"logprobs={self.logprobs}")
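from_dict can go because keyword defaults plus ** unpacking reproduce it, with Python's own argument checking standing in for the explicit unrecognized-key ValueError. The one behavior not carried over is the set(...) coercion from_dict applied to stop_token_ids, so callers must now pass a real set. Roughly (illustrative):

from cacheflow.sampling_params import SamplingParams

d = {"n": 2, "temperature": 0.5, "stop_token_ids": {2}}
# Old: SamplingParams.from_dict(d) popped each key and coerced stop_token_ids.
# New: direct unpacking, equivalent as long as values are already typed.
params = SamplingParams(**d)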
@@ -10,7 +10,7 @@ def http_bot(prompt):
     headers = {"User-Agent": "Cacheflow Client"}
     pload = {
         "prompt": prompt,
-        "max_num_steps": 128,
+        "max_tokens": 128,
     }
     response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
 
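Clients of a server running this commit must send the renamed JSON field. A client-side sketch using only what the diff shows (the URL stands in for args.model_url and is a placeholder):

import requests

pload = {"prompt": "The future of AI is", "max_tokens": 128}
response = requests.post("http://localhost:8001/generate",  # placeholder URL
                         headers={"User-Agent": "Cacheflow Client"},
                         json=pload, stream=True)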
@@ -18,7 +18,7 @@ def main(args: argparse.Namespace):
     while True:
         if test_inputs:
             text, sampling_params_dict = test_inputs.pop(0)
-            sampling_params = SamplingParams.from_dict(sampling_params_dict)
+            sampling_params = SamplingParams(**sampling_params_dict)
             sampling_params = frontend.add_eos_token(sampling_params)
             frontend.query(text, sampling_params)
             server.add_sequence_groups(frontend.get_inputs())