Clean up the server script
parent 6aef2278f4
commit fa16389a2e

server.py (89 changed lines: 25 additions, 64 deletions)
@@ -1,6 +1,7 @@
 import argparse
 from typing import List
 
+from cacheflow.master.frontend import Frontend
 from cacheflow.master.scheduler import Scheduler
 from cacheflow.worker.controller import Controller
 
@@ -8,15 +9,15 @@ parser = argparse.ArgumentParser(description='CacheFlow server')
 parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
 parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
 parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
-parser.add_argument('--block-size', type=int, default=8, help='block size')
-parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks')
-parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks')
+parser.add_argument('--block-size', type=int, default=8, help='token block size')
+# TODO(woosuk): Add an analytical model to determine the maximum number of GPU/CPU blocks.
+parser.add_argument('--num-gpu-blocks', type=int, default=1024, help='number of GPU blocks (per GPU)')
+parser.add_argument('--num-cpu-blocks', type=int, default=256, help='number of CPU blocks (per GPU)')
+args = parser.parse_args()
 
 
 def main():
-    args = parser.parse_args()
-
-    # Create controllers.
+    # Create a controller for each node.
     controllers: List[Controller] = []
     for i in range(args.num_nodes):
         controller = Controller(
@@ -26,12 +27,18 @@ def main():
             block_size=args.block_size,
             num_gpu_blocks=args.num_gpu_blocks,
             num_cpu_blocks=args.num_cpu_blocks,
-            dtype='float',
         )
         controllers.append(controller)
 
+    # Create a frontend.
+    frontend = Frontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+
     # Create a scheduler.
     scheduler = Scheduler(
+        frontend=frontend,
         controllers=controllers,
         block_size=args.block_size,
         num_gpu_blocks=args.num_gpu_blocks,
@@ -42,65 +49,19 @@ def main():
         controllers[i].set_next(controllers[i + 1])
     controllers[-1].set_next(scheduler)
 
-    # seq_groups, max_num_steps, stop_token_ids = generate_inputs(1000, args.block_size)
-    seq_groups, max_num_steps, stop_token_ids = test_inputs(args.block_size)
-    scheduler.pending.extend(seq_groups)
-    scheduler.max_num_steps.update(max_num_steps)
-    scheduler.stop_token_ids.update(stop_token_ids)
+    test_inputs = [
+        'Ion Stoica is a',
+        'UC Berkeley is',
+        'The future of cloud computing is',
+    ]
+    for prompt in test_inputs:
+        frontend.query(prompt)
 
-    while scheduler.pending or scheduler.running:
-        scheduler.prepare()
+    # FIXME
+    while True:
         scheduler.step()
-
-
-def test_inputs(block_size):
-    from cacheflow.sequence import Sequence
-    from cacheflow.sequence import SequenceGroup
-    from cacheflow.utils import Counter
-
-    from transformers import AutoTokenizer
-    tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
-    prompt = "Hello, I'm am conscious and"
-    prompt_tokens = tokenizer.encode(prompt)
-
-    seq = Sequence(0, prompt_tokens, block_size=block_size)
-    seq_group = SequenceGroup(0, [seq])
-    seq_groups = [seq_group]
-    max_num_steps = {0: 8}
-    stop_token_ids = {0: []}
-    return seq_groups, max_num_steps, stop_token_ids
-
-
-def generate_inputs(num_inputs, block_size):
-    import random
-    random.seed(0)
-
-    from cacheflow.sequence import Sequence
-    from cacheflow.sequence import SequenceGroup
-    from cacheflow.utils import Counter
-
-    seq_group_counter = Counter()
-    seq_counter = Counter()
-
-    max_num_steps = {}
-    stop_token_ids = {}
-    seq_groups = []
-    for _ in range(num_inputs):
-        seq_group_id = next(seq_group_counter)
-
-        prompt_len = random.randint(16, 128)
-        max_num_steps[seq_group_id] = random.randint(32, 1024)
-        stop_token_ids[seq_group_id] = []
-
-        seqs = []
-        for _ in range(2):
-            seq_id = next(seq_counter)
-            seq = Sequence(seq_id, [0] * prompt_len, block_size=block_size)
-            seqs.append(seq)
-        seq_group = SequenceGroup(seq_group_id, seqs)
-        seq_groups.append(seq_group)
-
-    return seq_groups, max_num_steps, stop_token_ids
+        if not scheduler.pending and not scheduler.running:
+            break
 
 
 if __name__ == '__main__':
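For reference, this is what a full invocation of the cleaned-up script looks like; the values shown are simply the defaults declared in the argparse hunk above:

python server.py --model facebook/opt-125m --num-nodes 1 --num-workers 1 \
    --block-size 8 --num-gpu-blocks 1024 --num-cpu-blocks 256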
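The set_next() calls in the last hunk arrange the controllers into a pipeline whose final stage reports back to the scheduler. A toy sketch of that chaining pattern, using stand-in classes rather than the real Controller/Scheduler from cacheflow (the forward() hand-off method is hypothetical; the diff does not show how stages actually pass work along):

# Toy stand-ins illustrating only the set_next() wiring from the diff.
class Stage:
    def __init__(self, name: str):
        self.name = name
        self.next = None

    def set_next(self, stage) -> None:
        self.next = stage

    def forward(self, item) -> None:
        # Hypothetical hand-off: print and pass the item to the next stage.
        print(f'{self.name} received {item!r}')
        if self.next is not None:
            self.next.forward(item)


controllers = [Stage(f'controller-{i}') for i in range(3)]
scheduler = Stage('scheduler')
# Same wiring as in main(): chain the controllers, then point the last
# one back at the scheduler.
for i in range(len(controllers) - 1):
    controllers[i].set_next(controllers[i + 1])
controllers[-1].set_next(scheduler)
controllers[0].forward('batch')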
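One side effect of the cleanup worth noting: args = parser.parse_args() now runs at import time, so importing server.py from a test or another tool would parse that process's own sys.argv. If that ever becomes a problem, a conventional import-safe variant (a hypothetical refactor, not something this commit does) looks like:

import argparse


def parse_args(argv=None):
    # Building the parser inside a function keeps importing this module
    # side-effect free; argv=None still falls back to sys.argv[1:].
    parser = argparse.ArgumentParser(description='CacheFlow server')
    parser.add_argument('--model', type=str, default='facebook/opt-125m',
                        help='model name')
    # ... the remaining flags from the diff would be registered here ...
    return parser.parse_args(argv)


def main(args):
    ...  # build the controllers, frontend, and scheduler as in the diff


if __name__ == '__main__':
    main(parse_args())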