[core] [2/N] refactor worker_base input preparation for multi-step (#7387)

This commit is contained in:
William Lin 2024-08-11 08:50:08 -07:00 committed by GitHub
parent 4fb7b52a2c
commit c08e2b3086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 31 deletions

View File

@ -264,6 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
def prepare_worker_input( def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput: self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine virtual_engine = execute_model_req.virtual_engine
num_steps = execute_model_req.num_steps
num_seq_groups = len(execute_model_req.seq_group_metadata_list) num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync. # they contain parameters to launch cudamemcpyasync.
@ -286,6 +287,7 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_out=blocks_to_swap_out, blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy, blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
num_steps=num_steps,
) )
@torch.inference_mode() @torch.inference_mode()

View File

@ -129,6 +129,7 @@ class WorkerInput:
blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0 virtual_engine: int = 0
num_steps: int = 1
@classmethod @classmethod
def from_broadcasted_tensor_dict( def from_broadcasted_tensor_dict(
@ -145,6 +146,7 @@ class WorkerInput:
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"], virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
) )
def as_broadcastable_tensor_dict( def as_broadcastable_tensor_dict(
@ -158,6 +160,7 @@ class WorkerInput:
"blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy, "blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine, "virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
} }
return tensor_dict return tensor_dict
@ -216,13 +219,50 @@ class LocalOrDistributedWorkerBase(WorkerBase):
""" """
raise NotImplementedError raise NotImplementedError
def execute_model( def _get_worker_input_from_broadcast(
self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
""" Get the worker input from the broadcasted tensor dict. """
assert self.do_metadata_broadcast
assert not self.is_driver_worker
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
model_input = (
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
broadcast_data))
return model_input, worker_input
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
""" Get the driver input and broadcast it to other workers. """
assert self.is_driver_worker
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
broadcast_tensor_dict(broadcast_data, src=0)
return model_input, worker_input
def prepare_input(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]: ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
"""Executes at least one model step on the given sequences, unless no """
sequences are provided.""" Prepare the inputs to ModelRunner and workers.
start_time = time.perf_counter() """
if self.is_driver_worker: if self.is_driver_worker:
if execute_model_req is None: if execute_model_req is None:
if self.do_metadata_broadcast: if self.do_metadata_broadcast:
@ -233,34 +273,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
# notify all other workers to stop their execution loop. # notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0) broadcast_tensor_dict({}, src=0)
return None return None
return self._get_driver_input_and_broadcast(execute_model_req)
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_data["num_steps"] = num_steps
broadcast_tensor_dict(broadcast_data, src=0)
else: else:
assert self.do_metadata_broadcast return self._get_worker_input_from_broadcast()
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
num_steps = broadcast_data.pop("num_steps") def execute_model(
worker_input = WorkerInput.from_broadcasted_tensor_dict( self,
broadcast_data) execute_model_req: Optional[ExecuteModelRequest] = None
model_input = ( ) -> Optional[List[SamplerOutput]]:
self.model_runner. """Executes at least one model step on the given sequences, unless no
make_model_input_from_broadcasted_tensor_dict(broadcast_data)) sequences are provided."""
start_time = time.perf_counter()
inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None
model_input, worker_input = inputs
num_steps = worker_input.num_steps
self.execute_worker(worker_input) self.execute_worker(worker_input)