[core] [2/N] refactor worker_base input preparation for multi-step (#7387)

This commit is contained in:
William Lin 2024-08-11 08:50:08 -07:00 committed by GitHub
parent 4fb7b52a2c
commit c08e2b3086
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 31 deletions

View File

@@ -264,6 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
def prepare_worker_input(
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
virtual_engine = execute_model_req.virtual_engine
num_steps = execute_model_req.num_steps
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
@@ -286,6 +287,7 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
)
@torch.inference_mode()

View File

@@ -129,6 +129,7 @@ class WorkerInput:
blocks_to_swap_out: Optional[torch.Tensor] = None
blocks_to_copy: Optional[torch.Tensor] = None
virtual_engine: int = 0
num_steps: int = 1
@classmethod
def from_broadcasted_tensor_dict(
@@ -145,6 +146,7 @@ class WorkerInput:
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
)
def as_broadcastable_tensor_dict(
@@ -158,6 +160,7 @@ class WorkerInput:
"blocks_to_swap_out": self.blocks_to_swap_out,
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
}
return tensor_dict
@@ -216,13 +219,50 @@ class LocalOrDistributedWorkerBase(WorkerBase):
"""
raise NotImplementedError
def execute_model(
def _get_worker_input_from_broadcast(
        self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
    """Receive this step's inputs from the driver's broadcast.

    Only valid on non-driver workers with metadata broadcast enabled.
    Returns ``None`` when the driver broadcast an empty dict, which is
    the signal to stop the worker execution loop.
    """
    assert self.do_metadata_broadcast
    assert not self.is_driver_worker
    tensor_dict = broadcast_tensor_dict(src=0)
    if not tensor_dict:
        # Empty broadcast => the driver has no work for this step.
        return None

    worker_input = WorkerInput.from_broadcasted_tensor_dict(tensor_dict)
    runner_input = (
        self.model_runner.make_model_input_from_broadcasted_tensor_dict(
            tensor_dict))
    return runner_input, worker_input
def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
    """Build this step's inputs on the driver and share them.

    Prepares both the worker-side input and the model-runner input
    from ``execute_model_req``; when metadata broadcast is enabled,
    the two broadcastable tensor dicts are merged and sent to the
    non-driver workers in a single broadcast.
    """
    assert self.is_driver_worker

    worker_input: WorkerInput = self.prepare_worker_input(
        execute_model_req=execute_model_req)
    model_input: ModelRunnerInputBase = (
        self.model_runner.prepare_model_input(
            execute_model_req.seq_group_metadata_list,
            execute_model_req.virtual_engine,
            execute_model_req.finished_requests_ids))

    if self.do_metadata_broadcast:
        # Merge both dicts into one payload so a single broadcast
        # covers worker- and model-side metadata.
        payload = worker_input.as_broadcastable_tensor_dict()
        payload.update(model_input.as_broadcastable_tensor_dict())
        broadcast_tensor_dict(payload, src=0)
    return model_input, worker_input
def prepare_input(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()
) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
"""
Prepare the inputs to ModelRunner and workers.
"""
if self.is_driver_worker:
if execute_model_req is None:
if self.do_metadata_broadcast:
@@ -233,34 +273,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
# notify all other workers to stop their execution loop.
broadcast_tensor_dict({}, src=0)
return None
worker_input: WorkerInput = self.prepare_worker_input(
execute_model_req=execute_model_req)
model_input: ModelRunnerInputBase = (
self.model_runner.prepare_model_input(
execute_model_req.seq_group_metadata_list,
execute_model_req.virtual_engine,
execute_model_req.finished_requests_ids))
num_steps = execute_model_req.num_steps
if self.do_metadata_broadcast:
broadcast_data = worker_input.as_broadcastable_tensor_dict()
broadcast_data.update(
model_input.as_broadcastable_tensor_dict())
broadcast_data["num_steps"] = num_steps
broadcast_tensor_dict(broadcast_data, src=0)
return self._get_driver_input_and_broadcast(execute_model_req)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)
if not broadcast_data:
return None
return self._get_worker_input_from_broadcast()
num_steps = broadcast_data.pop("num_steps")
worker_input = WorkerInput.from_broadcasted_tensor_dict(
broadcast_data)
model_input = (
self.model_runner.
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> Optional[List[SamplerOutput]]:
"""Executes at least one model step on the given sequences, unless no
sequences are provided."""
start_time = time.perf_counter()
inputs = self.prepare_input(execute_model_req)
if inputs is None:
return None
model_input, worker_input = inputs
num_steps = worker_input.num_steps
self.execute_worker(worker_input)