diff --git a/server.py b/server.py index 0727217f..a0e12ab1 100644 --- a/server.py +++ b/server.py @@ -1,6 +1,7 @@ import argparse from typing import List +from cacheflow.master.frontend import Frontend from cacheflow.master.scheduler import Scheduler from cacheflow.worker.controller import Controller @@ -8,15 +9,15 @@ parser = argparse.ArgumentParser(description='CacheFlow server') parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes') parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node') -parser.add_argument('--block-size', type=int, default=8, help='block size') -parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks') -parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks') +parser.add_argument('--block-size', type=int, default=8, help='token block size') +# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks. +parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)') +parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)') +args = parser.parse_args() def main(): - args = parser.parse_args() - - # Create controllers. + # Create a controller for each node. controllers: List[Controller] = [] for i in range(args.num_nodes): controller = Controller( @@ -26,12 +27,18 @@ def main(): block_size=args.block_size, num_gpu_blocks=args.num_gpu_blocks, num_cpu_blocks=args.num_cpu_blocks, - dtype='float', ) controllers.append(controller) + # Create a frontend. + frontend = Frontend( + model_name=args.model, + block_size=args.block_size, + ) + # Create a scheduler. scheduler = Scheduler( + frontend=frontend, controllers=controllers, block_size=args.block_size, num_gpu_blocks=args.num_gpu_blocks, @@ -42,65 +49,19 @@ def main(): controllers[i].set_next(controllers[i + 1]) controllers[-1].set_next(scheduler) - # seq_groups, max_num_steps, stop_token_ids = generate_inputs(1000, args.block_size) - seq_groups, max_num_steps, stop_token_ids = test_inputs(args.block_size) - scheduler.pending.extend(seq_groups) - scheduler.max_num_steps.update(max_num_steps) - scheduler.stop_token_ids.update(stop_token_ids) + test_inputs = [ + 'Ion Stoica is a', + 'UC Berkeley is', + 'The future of cloud computing is', + ] + for prompt in test_inputs: + frontend.query(prompt) - while scheduler.pending or scheduler.running: - scheduler.prepare() + # FIXME + while True: scheduler.step() - - -def test_inputs(block_size): - from cacheflow.sequence import Sequence - from cacheflow.sequence import SequenceGroup - from cacheflow.utils import Counter - - from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m') - prompt = "Hello, I'm am conscious and" - prompt_tokens = tokenizer.encode(prompt) - - seq = Sequence(0, prompt_tokens, block_size=block_size) - seq_group = SequenceGroup(0, [seq]) - seq_groups = [seq_group] - max_num_steps = {0: 8} - stop_token_ids = {0: []} - return seq_groups, max_num_steps, stop_token_ids - - -def generate_inputs(num_inputs, block_size): - import random - random.seed(0) - - from cacheflow.sequence import Sequence - from cacheflow.sequence import SequenceGroup - from cacheflow.utils import Counter - - seq_group_counter = Counter() - seq_counter = Counter() - - max_num_steps = {} - stop_token_ids = {} - seq_groups = [] - for _ in range(num_inputs): - seq_group_id = next(seq_group_counter) - - prompt_len = random.randint(16, 128) - max_num_steps[seq_group_id] = random.randint(32, 1024) - stop_token_ids[seq_group_id] = [] - - seqs = [] - for _ in range(2): - seq_id = next(seq_counter) - seq = Sequence(seq_id, [0] * prompt_len, block_size=block_size) - seqs.append(seq) - seq_group = SequenceGroup(seq_group_id, seqs) - seq_groups.append(seq_group) - - return seq_groups, max_num_steps, stop_token_ids + if not scheduler.pending and not scheduler.running: + break if __name__ == '__main__':