[core] [2/N] refactor worker_base input preparation for multi-step (#7387)
This commit is contained in:
parent
4fb7b52a2c
commit
c08e2b3086
@ -264,6 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
|||||||
def prepare_worker_input(
|
def prepare_worker_input(
|
||||||
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
|
||||||
virtual_engine = execute_model_req.virtual_engine
|
virtual_engine = execute_model_req.virtual_engine
|
||||||
|
num_steps = execute_model_req.num_steps
|
||||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||||
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
|
||||||
# they contain parameters to launch cudamemcpyasync.
|
# they contain parameters to launch cudamemcpyasync.
|
||||||
@ -286,6 +287,7 @@ class Worker(LocalOrDistributedWorkerBase):
|
|||||||
blocks_to_swap_out=blocks_to_swap_out,
|
blocks_to_swap_out=blocks_to_swap_out,
|
||||||
blocks_to_copy=blocks_to_copy,
|
blocks_to_copy=blocks_to_copy,
|
||||||
virtual_engine=virtual_engine,
|
virtual_engine=virtual_engine,
|
||||||
|
num_steps=num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
|
|||||||
@ -129,6 +129,7 @@ class WorkerInput:
|
|||||||
blocks_to_swap_out: Optional[torch.Tensor] = None
|
blocks_to_swap_out: Optional[torch.Tensor] = None
|
||||||
blocks_to_copy: Optional[torch.Tensor] = None
|
blocks_to_copy: Optional[torch.Tensor] = None
|
||||||
virtual_engine: int = 0
|
virtual_engine: int = 0
|
||||||
|
num_steps: int = 1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_broadcasted_tensor_dict(
|
def from_broadcasted_tensor_dict(
|
||||||
@ -145,6 +146,7 @@ class WorkerInput:
|
|||||||
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
|
blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
|
||||||
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
|
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
|
||||||
virtual_engine=tensor_dict["virtual_engine"],
|
virtual_engine=tensor_dict["virtual_engine"],
|
||||||
|
num_steps=tensor_dict.pop("num_steps"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def as_broadcastable_tensor_dict(
|
def as_broadcastable_tensor_dict(
|
||||||
@ -158,6 +160,7 @@ class WorkerInput:
|
|||||||
"blocks_to_swap_out": self.blocks_to_swap_out,
|
"blocks_to_swap_out": self.blocks_to_swap_out,
|
||||||
"blocks_to_copy": self.blocks_to_copy,
|
"blocks_to_copy": self.blocks_to_copy,
|
||||||
"virtual_engine": self.virtual_engine,
|
"virtual_engine": self.virtual_engine,
|
||||||
|
"num_steps": self.num_steps,
|
||||||
}
|
}
|
||||||
|
|
||||||
return tensor_dict
|
return tensor_dict
|
||||||
@ -216,13 +219,50 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def execute_model(
|
def _get_worker_input_from_broadcast(
|
||||||
|
self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
|
||||||
|
""" Get the worker input from the broadcasted tensor dict. """
|
||||||
|
assert self.do_metadata_broadcast
|
||||||
|
assert not self.is_driver_worker
|
||||||
|
broadcast_data = broadcast_tensor_dict(src=0)
|
||||||
|
if not broadcast_data:
|
||||||
|
return None
|
||||||
|
|
||||||
|
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
|
||||||
|
model_input = (
|
||||||
|
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||||
|
broadcast_data))
|
||||||
|
|
||||||
|
return model_input, worker_input
|
||||||
|
|
||||||
|
def _get_driver_input_and_broadcast(
|
||||||
|
self, execute_model_req: ExecuteModelRequest
|
||||||
|
) -> Tuple[ModelRunnerInputBase, WorkerInput]:
|
||||||
|
""" Get the driver input and broadcast it to other workers. """
|
||||||
|
assert self.is_driver_worker
|
||||||
|
|
||||||
|
worker_input: WorkerInput = self.prepare_worker_input(
|
||||||
|
execute_model_req=execute_model_req)
|
||||||
|
model_input: ModelRunnerInputBase = (
|
||||||
|
self.model_runner.prepare_model_input(
|
||||||
|
execute_model_req.seq_group_metadata_list,
|
||||||
|
execute_model_req.virtual_engine,
|
||||||
|
execute_model_req.finished_requests_ids))
|
||||||
|
|
||||||
|
if self.do_metadata_broadcast:
|
||||||
|
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
||||||
|
broadcast_data.update(model_input.as_broadcastable_tensor_dict())
|
||||||
|
broadcast_tensor_dict(broadcast_data, src=0)
|
||||||
|
|
||||||
|
return model_input, worker_input
|
||||||
|
|
||||||
|
def prepare_input(
|
||||||
self,
|
self,
|
||||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
) -> Optional[List[SamplerOutput]]:
|
) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
|
||||||
"""Executes at least one model step on the given sequences, unless no
|
"""
|
||||||
sequences are provided."""
|
Prepare the inputs to ModelRunner and workers.
|
||||||
start_time = time.perf_counter()
|
"""
|
||||||
if self.is_driver_worker:
|
if self.is_driver_worker:
|
||||||
if execute_model_req is None:
|
if execute_model_req is None:
|
||||||
if self.do_metadata_broadcast:
|
if self.do_metadata_broadcast:
|
||||||
@ -233,34 +273,24 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
|||||||
# notify all other workers to stop their execution loop.
|
# notify all other workers to stop their execution loop.
|
||||||
broadcast_tensor_dict({}, src=0)
|
broadcast_tensor_dict({}, src=0)
|
||||||
return None
|
return None
|
||||||
|
return self._get_driver_input_and_broadcast(execute_model_req)
|
||||||
worker_input: WorkerInput = self.prepare_worker_input(
|
|
||||||
execute_model_req=execute_model_req)
|
|
||||||
model_input: ModelRunnerInputBase = (
|
|
||||||
self.model_runner.prepare_model_input(
|
|
||||||
execute_model_req.seq_group_metadata_list,
|
|
||||||
execute_model_req.virtual_engine,
|
|
||||||
execute_model_req.finished_requests_ids))
|
|
||||||
num_steps = execute_model_req.num_steps
|
|
||||||
|
|
||||||
if self.do_metadata_broadcast:
|
|
||||||
broadcast_data = worker_input.as_broadcastable_tensor_dict()
|
|
||||||
broadcast_data.update(
|
|
||||||
model_input.as_broadcastable_tensor_dict())
|
|
||||||
broadcast_data["num_steps"] = num_steps
|
|
||||||
broadcast_tensor_dict(broadcast_data, src=0)
|
|
||||||
else:
|
else:
|
||||||
assert self.do_metadata_broadcast
|
return self._get_worker_input_from_broadcast()
|
||||||
broadcast_data = broadcast_tensor_dict(src=0)
|
|
||||||
if not broadcast_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
num_steps = broadcast_data.pop("num_steps")
|
def execute_model(
|
||||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(
|
self,
|
||||||
broadcast_data)
|
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||||
model_input = (
|
) -> Optional[List[SamplerOutput]]:
|
||||||
self.model_runner.
|
"""Executes at least one model step on the given sequences, unless no
|
||||||
make_model_input_from_broadcasted_tensor_dict(broadcast_data))
|
sequences are provided."""
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
inputs = self.prepare_input(execute_model_req)
|
||||||
|
if inputs is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
model_input, worker_input = inputs
|
||||||
|
num_steps = worker_input.num_steps
|
||||||
|
|
||||||
self.execute_worker(worker_input)
|
self.execute_worker(worker_input)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user