[core] [2/N] refactor worker_base input preparation for multi-step (#7387)
This commit is contained in:
parent
4fb7b52a2c
commit
c08e2b3086
@ -264,6 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
def prepare_worker_input(
|
||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_steps = execute_model_req.num_steps
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
||||
# they contain parameters to launch cudamemcpyasync.
|
||||
@ -286,6 +287,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
num_steps=num_steps,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@ -129,6 +129,7 @@ class WorkerInput:
|
||||
blocks_to_swap_out: Optional[torch.Tensor] = None
|
||||
blocks_to_copy: Optional[torch.Tensor] = None
|
||||
virtual_engine: int = 0
|
||||
num_steps: int = 1
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
@ -145,6 +146,7 @@ class WorkerInput:
|
||||
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
|
||||
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
|
||||
virtual_engine=tensor_dict["virtual_engine"],
|
||||
num_steps=tensor_dict.pop("num_steps"),
|
||||
)
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
@ -158,6 +160,7 @@ class WorkerInput:
|
||||
"blocks_to_swap_out": self.blocks_to_swap_out,
|
||||
"blocks_to_copy": self.blocks_to_copy,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"num_steps": self.num_steps,
|
||||
}
|
||||
|
||||
return tensor_dict
|
||||
@ -216,13 +219,50 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def execute_model(
|
||||
def _get_worker_input_from_broadcast(
|
||||
self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
|
||||
""" Get the worker input from the broadcasted tensor dict. """
|
||||
assert self.do_metadata_broadcast
|
||||
assert not self.is_driver_worker
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
if not broadcast_data:
|
||||
return None
|
||||
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||
broadcast_data))
|
||||
|
||||
return model_input, worker_input
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
|
||||
""" Get the driver input and broadcast it to other workers. """
|
||||
assert self.is_driver_worker
|
||||
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: ModelRunnerInputBase = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
|
||||
return model_input, worker_input
|
||||
|
||||
def prepare_input(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Executes at least one model step on the given sequences, unless no
|
||||
sequences are provided."""
|
||||
start_time = time.perf_counter()
|
||||
) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
|
||||
"""
|
||||
Prepare the inputs to ModelRunner and workers.
|
||||
"""
|
||||
if self.is_driver_worker:
|
||||
if execute_model_req is None:
|
||||
if self.do_metadata_broadcast:
|
||||
@ -233,34 +273,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
# notify all other workers to stop their execution loop.
|
||||
broadcast_tensor_dict({}, src=0)
|
||||
return None
|
||||
|
||||
worker_input: WorkerInput = self.prepare_worker_input(
|
||||
execute_model_req=execute_model_req)
|
||||
model_input: ModelRunnerInputBase = (
|
||||
self.model_runner.prepare_model_input(
|
||||
execute_model_req.seq_group_metadata_list,
|
||||
execute_model_req.virtual_engine,
|
||||
execute_model_req.finished_requests_ids))
|
||||
num_steps = execute_model_req.num_steps
|
||||
|
||||
if self.do_metadata_broadcast:
|
||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||
broadcast_data.update(
|
||||
model_input.as_broadcastable_tensor_dict())
|
||||
broadcast_data["num_steps"] = num_steps
|
||||
broadcast_tensor_dict(broadcast_data, src=0)
|
||||
return self._get_driver_input_and_broadcast(execute_model_req)
|
||||
else:
|
||||
assert self.do_metadata_broadcast
|
||||
broadcast_data = broadcast_tensor_dict(src=0)
|
||||
if not broadcast_data:
|
||||
return None
|
||||
return self._get_worker_input_from_broadcast()
|
||||
|
||||
num_steps = broadcast_data.pop("num_steps")
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(
|
||||
broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.
|
||||
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Executes at least one model step on the given sequences, unless no
|
||||
sequences are provided."""
|
||||
start_time = time.perf_counter()
|
||||
|
||||
inputs = self.prepare_input(execute_model_req)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
model_input, worker_input = inputs
|
||||
num_steps = worker_input.num_steps
|
||||
|
||||
self.execute_worker(worker_input)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user