match tp+pp loss
This commit is contained in:
parent
63307c79a1
commit
51b5683dd3
@ -44,8 +44,8 @@ def train_step_pipeline_afab(model, data_loader, tensor_shapes, device):
|
||||
output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
|
||||
pipeline_communicate(operation='send_forward', tensor=output_tensor, device=device, dtype=dtype)
|
||||
|
||||
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
|
||||
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
|
||||
# calculate loss on the last stage
|
||||
if pgm.process_group_manager.pp_is_last_stage:
|
||||
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
|
||||
logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
|
||||
|
||||
@ -70,8 +70,9 @@ def train_step_pipeline_1f1b(model, data_loader, tensor_shapes, device):
|
||||
batch = next(data_loader)
|
||||
batch["hidden_states"] = input_tensor.to(device) if input_tensor is not None else input_tensor
|
||||
output_tensor = model.forward(input_ids=batch["input_ids"].to(device), position_ids=batch["position_ids"].to(device), hidden_states=batch["hidden_states"])
|
||||
# Don't need to keep track of the loss on every rank. Just choosing a single rank (TP rank 0 in the last PP stage) is enough
|
||||
if pgm.process_group_manager.pp_is_last_stage and pgm.process_group_manager.global_rank == pgm.process_group_manager.tp_first_rank:
|
||||
|
||||
# calculate loss on the last stage
|
||||
if pgm.process_group_manager.pp_is_last_stage:
|
||||
output_tensor = F.cross_entropy(output_tensor.transpose(1, 2), batch["target_ids"].to(device), reduction='mean')
|
||||
nonlocal logging_loss
|
||||
logging_loss += output_tensor.item() / data_loader.num_local_micro_batches
|
||||
|
||||
Loading…
Reference in New Issue
Block a user