disable grad sync in afab
commit c7a3fb016a
parent 47c00be8c7
train.py

@@ -51,7 +51,10 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device, dtype):
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
 
-    for _ in range(data_loader.num_local_micro_batches): # All backward passes
+    for i in range(data_loader.num_local_micro_batches): # All backward passes
+        requires_grad_sync = pgm.process_group_manager.cp_dp_world_size > 1
+        if requires_grad_sync:
+            model.require_backward_grad_sync = (i == data_loader.num_local_micro_batches - 1)
         output_tensor_grad = pipeline_communicate(operation='recv_backward', shapes=tensor_shapes, device=device, dtype=dtype)
         input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
         input_tensor_grad = model.backward(input_tensor, output_tensor, output_tensor_grad)
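The flag only matters if the data-parallel wrapper consults it before launching its gradient all-reduce. Below is a minimal sketch of such a wrapper, assuming a hook-based design; the class name SimpleDataParallel, the process-group argument, and the hook wiring are illustrative assumptions, not this repository's actual implementation.

import torch.distributed as dist
import torch.nn as nn

class SimpleDataParallel(nn.Module):
    """Illustrative DP wrapper: all-reduces gradients in a post-accumulation
    hook, but skips the collective while require_backward_grad_sync is False."""
    def __init__(self, module: nn.Module, process_group=None):
        super().__init__()
        self.module = module
        self.process_group = process_group
        self.require_backward_grad_sync = True  # the train loop toggles this per micro-batch
        for p in self.module.parameters():
            if p.requires_grad:
                # Fires once the gradient for p has been fully accumulated.
                p.register_post_accumulate_grad_hook(self._maybe_all_reduce)

    def _maybe_all_reduce(self, param):
        if not self.require_backward_grad_sync:
            return  # keep accumulating local gradients across micro-batches
        dist.all_reduce(param.grad, group=self.process_group)
        param.grad /= dist.get_world_size(self.process_group)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

With a wrapper like this, setting require_backward_grad_sync only on the last micro-batch means the all-reduce runs once per step over the already-accumulated gradients instead of once per micro-batch.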
@@ -151,13 +151,14 @@ if __name__ == "__main__":
     )
 
     start_time = time.time()
-    model.to(dtype).to(device)
 
     if pgm.process_group_manager.tp_world_size > 1:
         TensorParallel(model)
 
     if pgm.process_group_manager.pp_world_size > 1:
         model = PipelineParallel(model, model_config)
 
+    model.to(dtype).to(device)
 
+    # Context parallel and Data parallel both need gradient synchronization
     if pgm.process_group_manager.cp_dp_world_size > 1:
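For comparison, the same deferral is what stock PyTorch DistributedDataParallel exposes through its no_sync() context manager. The loop below is a rough sketch of the micro-batch pattern with DDP instead of this repository's wrappers; the data and loss plumbing are placeholders.

import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

def grad_accumulation_step(ddp_model: DDP, micro_batches):
    """Accumulate gradients locally; all-reduce only on the last backward."""
    last = len(micro_batches) - 1
    for i, (inputs, targets) in enumerate(micro_batches):
        loss = F.cross_entropy(ddp_model(inputs), targets)
        if i < last:
            with ddp_model.no_sync():  # skip the gradient all-reduce for this backward
                loss.backward()
        else:
            loss.backward()            # final backward triggers the synchronization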