Clean up the server script

Woosuk Kwon 2023-02-24 11:56:21 +00:00
parent 6aef2278f4
commit fa16389a2e


@@ -1,6 +1,7 @@
 import argparse
 from typing import List
 
+from cacheflow.master.frontend import Frontend
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.worker.controller import Controller
 
@@ -8,15 +9,15 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='block size')
-parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks')
+parser.add_argument('--block-size', type=int, default=8, help='token block size')
+# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
+parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+args = parser.parse_args()
 
 
 def main():
-    args = parser.parse_args()
-
-    # Create controllers.
+    # Create a controller for each node.
     controllers: List[Controller] = []
     for i in range(args.num_nodes):
         controller = Controller(
@@ -26,12 +27,18 @@ def main():
             block_size=args.block_size,
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
+            dtype='float',
         )
         controllers.append(controller)
 
+    # Create a frontend.
+    frontend = Frontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+
     # Create a scheduler.
     scheduler = Scheduler(
+        frontend=frontend,
         controllers=controllers,
         block_size=args.block_size,
         num_gpu_blocks=args.num_gpu_blocks,
@@ -42,65 +49,19 @@ def main():
         controllers[i].set_next(controllers[i + 1])
     controllers[-1].set_next(scheduler)
 
-    # seq_groups, max_num_steps, stop_token_ids = generate_inputs(1000, args.block_size)
-    seq_groups, max_num_steps, stop_token_ids = test_inputs(args.block_size)
-    scheduler.pending.extend(seq_groups)
-    scheduler.max_num_steps.update(max_num_steps)
-    scheduler.stop_token_ids.update(stop_token_ids)
+    test_inputs = [
+        'Ion Stoica is a',
+        'UC Berkeley is',
+        'The future of cloud computing is',
+    ]
+    for prompt in test_inputs:
+        frontend.query(prompt)
 
-    while scheduler.pending or scheduler.running:
-        scheduler.prepare()
+    # FIXME
+    while True:
         scheduler.step()
+        if not scheduler.pending and not scheduler.running:
+            break
 
-
-def test_inputs(block_size):
-    from cacheflow.sequence import Sequence
-    from cacheflow.sequence import SequenceGroup
-    from cacheflow.utils import Counter
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-    prompt = "Hello, I'm am conscious and"
-    prompt_tokens = tokenizer.encode(prompt)
-    seq = Sequence(0, prompt_tokens, block_size=block_size)
-    seq_group = SequenceGroup(0, [seq])
-    seq_groups = [seq_group]
-    max_num_steps = {0: 8}
-    stop_token_ids = {0: []}
-    return seq_groups, max_num_steps, stop_token_ids
-
-
-def generate_inputs(num_inputs, block_size):
-    import random
-    random.seed(0)
-    from cacheflow.sequence import Sequence
-    from cacheflow.sequence import SequenceGroup
-    from cacheflow.utils import Counter
-
-    seq_group_counter = Counter()
-    seq_counter = Counter()
-    max_num_steps = {}
-    stop_token_ids = {}
-    seq_groups = []
-    for _ in range(num_inputs):
-        seq_group_id = next(seq_group_counter)
-        prompt_len = random.randint(16, 128)
-        max_num_steps[seq_group_id] = random.randint(32, 1024)
-        stop_token_ids[seq_group_id] = []
-        seqs = []
-        for _ in range(2):
-            seq_id = next(seq_counter)
-            seq = Sequence(seq_id, [0] * prompt_len, block_size=block_size)
-            seqs.append(seq)
-        seq_group = SequenceGroup(seq_group_id, seqs)
-        seq_groups.append(seq_group)
-    return seq_groups, max_num_steps, stop_token_ids
-
 
 if __name__ == '__main__':
     main()
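
For readers skimming the hunks above, the following is a minimal, self-contained sketch of the control flow the cleaned-up script ends up with. The class names mirror the real ones, but the bodies are hypothetical stubs (the real Frontend tokenizes prompts and the real Controller drives workers); only the wiring, namely frontend.query(), the set_next() chaining, and the step-until-idle loop, is taken from the diff.

# A minimal, self-contained sketch of the control flow this commit sets up.
# Frontend, Controller, and Scheduler here are hypothetical stand-ins, NOT
# the real cacheflow classes: the frontend enqueues prompts, controllers
# form a pipeline via set_next(), and the scheduler is stepped until no
# work is pending or running, mirroring the FIXME loop in the diff.
from typing import List, Optional


class Frontend:
    def __init__(self) -> None:
        self.queue: List[str] = []

    def query(self, prompt: str) -> None:
        # The real frontend would tokenize the prompt; here we just enqueue it.
        self.queue.append(prompt)


class Controller:
    def __init__(self, node_id: int) -> None:
        self.node_id = node_id
        self.next_stage: Optional[object] = None

    def set_next(self, stage: object) -> None:
        # Route finished work to the next stage of the pipeline.
        self.next_stage = stage


class Scheduler:
    def __init__(self, frontend: Frontend, controllers: List[Controller]) -> None:
        self.frontend = frontend
        self.controllers = controllers
        self.pending: List[str] = []
        self.running: List[str] = []

    def step(self) -> None:
        # Admit newly queried prompts, then pretend to finish one per step.
        self.pending.extend(self.frontend.queue)
        self.frontend.queue.clear()
        if self.pending:
            self.running.append(self.pending.pop(0))
        if self.running:
            print('finished:', self.running.pop(0))


def main() -> None:
    controllers = [Controller(i) for i in range(2)]
    frontend = Frontend()
    scheduler = Scheduler(frontend, controllers)
    # Chain the controllers and terminate the pipeline at the scheduler,
    # exactly as the server script does.
    for i in range(len(controllers) - 1):
        controllers[i].set_next(controllers[i + 1])
    controllers[-1].set_next(scheduler)

    for prompt in ['Ion Stoica is a', 'UC Berkeley is']:
        frontend.query(prompt)

    # Same exit condition as the commit's FIXME'd loop.
    while True:
        scheduler.step()
        if not scheduler.pending and not scheduler.running:
            break


if __name__ == '__main__':
    main()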